
Migrated project

pull/2/head
zbian 3 years ago
commit 404ecbdcc6
1. 144  .gitignore
2. 4  MANIFEST.in
3. 104  README.md
4. 4  colossalai/__init__.py
5. 2  colossalai/builder/__init__.py
6. 262  colossalai/builder/builder.py
7. 226  colossalai/builder/pipeline.py
8. 215  colossalai/checkpointing.py
9. 14  colossalai/communication/__init__.py
10. 84  colossalai/communication/collective.py
11. 333  colossalai/communication/p2p.py
12. 54  colossalai/communication/ring.py
13. 73  colossalai/communication/utils.py
14. 31  colossalai/constants.py
15. 5  colossalai/context/__init__.py
16. 70  colossalai/context/_utils.py
17. 99  colossalai/context/config.py
18. 454  colossalai/context/parallel_context.py
19. 44  colossalai/context/parallel_mode.py
20. 15  colossalai/context/process_group_initializer/__init__.py
21. 44  colossalai/context/process_group_initializer/initializer_1d.py
22. 123  colossalai/context/process_group_initializer/initializer_2d.py
23. 255  colossalai/context/process_group_initializer/initializer_2p5d.py
24. 172  colossalai/context/process_group_initializer/initializer_3d.py
25. 41  colossalai/context/process_group_initializer/initializer_data.py
26. 63  colossalai/context/process_group_initializer/initializer_pipeline.py
27. 27  colossalai/context/process_group_initializer/initializer_sequence.py
28. 41  colossalai/context/process_group_initializer/initializer_tensor.py
29. 30  colossalai/context/process_group_initializer/process_group_initializer.py
30. 8  colossalai/context/random/__init__.py
31. 144  colossalai/context/random/_helper.py
32. 74  colossalai/context/random/seed_manager.py
33. 16  colossalai/core.py
34. 7  colossalai/engine/__init__.py
35. 170  colossalai/engine/_base_engine.py
36. 10  colossalai/engine/amp_type.py
37. 5  colossalai/engine/gradient_handler/__init__.py
38. 25  colossalai/engine/gradient_handler/_base_gradient_handler.py
39. 48  colossalai/engine/gradient_handler/_data_parallel_gradient_handler.py
40. 16  colossalai/engine/gradient_handler/_zero_gradient_handler.py
41. 5  colossalai/engine/schedule/__init__.py
42. 129  colossalai/engine/schedule/_base_schedule.py
43. 185  colossalai/engine/schedule/_no_pipeline.py
44. 316  colossalai/engine/schedule/_pipeline.py
45. 16  colossalai/engine/schedule/_utils.py
46. 371  colossalai/initialize.py
47. 26  colossalai/logging/__init__.py
48. 97  colossalai/logging/logging.py
49. 6  colossalai/nn/__init__.py
50. 3  colossalai/nn/data/__init__.py
51. 14  colossalai/nn/data/_utils.py
52. 17  colossalai/nn/data/base_dataset.py
53. 43  colossalai/nn/data/caltech101_dataset.py
54. 44  colossalai/nn/data/cifar10_dataset.py
55. 4  colossalai/nn/data/sampler/__init__.py
56. 19  colossalai/nn/data/sampler/base_sampler.py
57. 102  colossalai/nn/data/sampler/data_parallel_sampler.py
58. 9  colossalai/nn/layer/__init__.py
59. 63  colossalai/nn/layer/_common_utils.py
60. 138  colossalai/nn/layer/_parallel_utilities.py
61. 27  colossalai/nn/layer/base_layer.py
62. 5  colossalai/nn/layer/parallel_1d/__init__.py
63. 15  colossalai/nn/layer/parallel_1d/_utils.py
64. 166  colossalai/nn/layer/parallel_1d/layers.py
65. 11  colossalai/nn/layer/parallel_2d/__init__.py
66. 522  colossalai/nn/layer/parallel_2d/_operation.py
67. 220  colossalai/nn/layer/parallel_2d/_transformer.py
68. 23  colossalai/nn/layer/parallel_2d/_utils.py
69. 391  colossalai/nn/layer/parallel_2d/_vit.py
70. 258  colossalai/nn/layer/parallel_2d/layers.py
71. 13  colossalai/nn/layer/parallel_2p5d/__init__.py
72. 535  colossalai/nn/layer/parallel_2p5d/_operation.py
73. 206  colossalai/nn/layer/parallel_2p5d/_transformer.py
74. 25  colossalai/nn/layer/parallel_2p5d/_utils.py
75. 351  colossalai/nn/layer/parallel_2p5d/_vit.py
76. 266  colossalai/nn/layer/parallel_2p5d/layers.py
77. 9  colossalai/nn/layer/parallel_3d/__init__.py
78. 349  colossalai/nn/layer/parallel_3d/_operation.py
79. 49  colossalai/nn/layer/parallel_3d/_utils.py
80. 368  colossalai/nn/layer/parallel_3d/_vit.py
81. 172  colossalai/nn/layer/parallel_3d/layers.py
82. 4  colossalai/nn/layer/parallel_sequence/__init__.py
83. 169  colossalai/nn/layer/parallel_sequence/_operation.py
84. 15  colossalai/nn/layer/parallel_sequence/_utils.py
85. 188  colossalai/nn/layer/parallel_sequence/layers.py
86. 3  colossalai/nn/layer/parallel_vision_transformer/__init__.py
87. 59  colossalai/nn/layer/parallel_vision_transformer/layers.py
88. 5  colossalai/nn/layer/vanilla_resnet/__init__.py
89. 64  colossalai/nn/layer/vanilla_resnet/basic_block.py
90. 69  colossalai/nn/layer/vanilla_resnet/bottleneck.py
91. 15  colossalai/nn/layer/vanilla_resnet/conv.py
92. 63  colossalai/nn/layer/vanilla_resnet/reslayer.py
93. 7  colossalai/nn/layer/vanilla_vision_transformer/__init__.py
94. 244  colossalai/nn/layer/vanilla_vision_transformer/layers.py
95. 3  colossalai/nn/layer/wrapper/__init__.py
96. 37  colossalai/nn/layer/wrapper/lambda_wrapper.py
97. 6  colossalai/nn/loss/__init__.py
98. 13  colossalai/nn/loss/base_loss.py
99. 120  colossalai/nn/loss/cross_entropy_1d.py
100. 128  colossalai/nn/loss/cross_entropy_2d.py
Some files were not shown because too many files have changed in this diff.

144
.gitignore vendored

@@ -0,0 +1,144 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
docs/.build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# IDE
.idea/
.vscode/
# macos
.DS_Store
#data/
# launcher setting
tests/launcher/log
tests/launcher/personal
docs/.build

4
MANIFEST.in

@@ -0,0 +1,4 @@
include *.txt README.md
recursive-include requirements *.txt
recursive-include colossalai *.cpp *.h *.cu *.tr *.cuh *.cc
recursive-include csrc *.cpp *.h *.cu *.tr *.cuh *.cc

104
README.md

@@ -0,0 +1,104 @@
# ColossalAI
An integrated large-scale model training framework with efficient parallelization techniques
## Installation
### PyPI
```bash
pip install colossalai
```
### Install From Source
```shell
git clone git@github.com:hpcaitech/ColossalAI.git
cd ColossalAI
# install dependency
pip install -r requirements/requirements.txt
# install colossalai
pip install .
```
Install and enable CUDA kernel fusion (required when using a fused optimizer):
```shell
pip install -v --no-cache-dir --global-option="--cuda_ext" .
```
## Documentation
- [Documentation](https://www.colossalai.org/)
## Quick View
### Start Distributed Training in Lines
```python
import colossalai
from colossalai.engine import Engine
from colossalai.trainer import Trainer
from colossalai.core import global_context as gpc
model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize()
engine = Engine(
model=model,
criterion=criterion,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
schedule=schedule
)
trainer = Trainer(engine=engine,
hooks_cfg=gpc.config.hooks,
verbose=True)
trainer.fit(
train_dataloader=train_dataloader,
test_dataloader=test_dataloader,
max_epochs=gpc.config.num_epochs,
display_progress=True,
test_interval=5
)
```
### Write a Simple 2D Parallel Model
Suppose we have a huge MLP model whose very large hidden size makes it difficult to fit into a single GPU. We can
then distribute the model weights across GPUs in a 2D mesh while still writing the model in a familiar way.
```python
from colossalai.nn import Linear2D
import torch.nn as nn
class MLP_2D(nn.Module):
def __init__(self):
super().__init__()
self.linear_1 = Linear2D(in_features=1024, out_features=16384)
self.linear_2 = Linear2D(in_features=16384, out_features=1024)
def forward(self, x):
x = self.linear_1(x)
x = self.linear_2(x)
return x
```
## Features
ColossalAI provides a collection of parallel training components. We aim to let you write your
distributed deep learning models just like you write your single-GPU model, and we provide friendly tools to kick off
distributed training in a few lines.
- [Data Parallelism](./docs/parallelization.md)
- [Pipeline Parallelism](./docs/parallelization.md)
- [1D, 2D, 2.5D, 3D and sequence parallelism](./docs/parallelization.md)
- [Friendly trainer and engine](./docs/trainer_engine.md)
- [Extensible for new parallelism](./docs/add_your_parallel.md)
- [Mixed Precision Training](./docs/amp.md)
- [Zero Redundancy Optimizer (ZeRO)](./docs/zero.md)

4
colossalai/__init__.py

@@ -0,0 +1,4 @@
from .initialize import init_dist, initialize
from .nn import *
__version__ = '0.0.1'

2
colossalai/builder/__init__.py

@@ -0,0 +1,2 @@
from .builder import *
from .pipeline import ModelInitializer

262
colossalai/builder/builder.py

@@ -0,0 +1,262 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import inspect
from collections.abc import Iterable
from colossalai.registry import *
def build_from_config(module, config: dict):
"""Returns an object of :class:`module` constructed from `config`.
:param module: A python or user-defined class
:type module: class
:param config: A python dict containing information used in the construction
of the return object
:type config: dict
:raises AssertionError: Raises an AssertionError if `module` is not a class
:return: An object of :class:`module`
:rtype: :class:`module`
"""
assert inspect.isclass(module), 'module must be a class'
return module(**config)
def build_from_registry(config, registry: Registry):
"""Returns an object constructed from `config`, the type of the object
is specified by `registry`.
:param config: A python dict or a :class:`colossalai.context.Config` object
containing information used in the construction of the return object
:type config: dict or :class:`colossalai.context.Config`
:param registry: A registry specifying the type of the return object
:type registry: :class:`Registry`
:raises AssertionError: Raises an AssertionError if `registry` is not an object
of :class:`Registry` or `mod_type` in `config` is not found in `registry`
:raises Exception: Raises an Exception if an error occurred when building
from registry
:return: An object specified by `registry`
:rtype: Python object specified by `registry`
"""
config_ = config.copy() # keep the original config untouched
assert isinstance(
registry, Registry), f'Expected type Registry but got {type(registry)}'
mod_type = config_.pop('type')
assert registry.has(
mod_type), f'{mod_type} is not found in registry {registry.name}'
try:
obj = registry.get_module(mod_type)(**config_)
except Exception as e:
print(
f'An error occurred when building {mod_type} from registry {registry.name}', flush=True)
raise e
return obj
def build_layer(config):
"""Returns a layer object of :class:`nn.Module` constructed from `config`.
:param config: A python dict or a :class:`colossalai.context.Config` object
containing information used in the construction of the return object
:type config: dict or :class:`colossalai.context.Config`
:return: An object of :class:`nn.Module`
:rtype: :class:`nn.Module`
"""
return build_from_registry(config, LAYERS)
def build_loss(config):
"""Returns a loss function object of :class:`torch.autograd.Function` constructed
from `config`.
:param config: A python dict or a :class:`colossalai.context.Config` object
containing information used in the construction of the return object
:type config: dict or :class:`colossalai.context.Config`
:return: An object of :class:`torch.autograd.Function`
:rtype: :class:`torch.autograd.Function`
"""
return build_from_registry(config, LOSSES)
def build_model(config):
"""Returns a model object of :class:`nn.Module` constructed from `config`.
:param config: A python dict or a :class:`colossalai.context.Config` object
containing information used in the construction of the return object
:type config: dict or :class:`colossalai.context.Config`
:return: An object of :class:`nn.Module`
:rtype: :class:`nn.Module`
"""
return build_from_registry(config, MODELS)
def build_dataset(config):
"""Returns a dataset object of :class:`torch.utils.data.Dataset` constructed
from `config`.
:param config: A python dict or a :class:`colossalai.context.Config` object
containing information used in the construction of the return object
:type config: dict or :class:`colossalai.context.Config`
:return: An object of :class:`torch.utils.data.Dataset`
:rtype: :class:`torch.utils.data.Dataset`
"""
return build_from_registry(config, DATASETS)
def build_optimizer(config, model, params: Iterable = None, need_module=False):
"""Returns an optimizer object of :class:`torch.optim.Optimizer` constructed from `config`,
'model' and 'params'.
:param config: A python dict or a :class:`colossalai.context.Config` object
containing information used in the construction of the return object
:type config: dict or :class:`colossalai.context.Config`
:param model: A model containing parameters for the optimizer
:type model: :class:`nn.Module`
:param params: An iterable of parameters for the optimizer
:type params: Iterable, optional
:param need_module: Indicates whether the optimizer needs access to the module itself
:type need_module: bool, optional
:raises AssertionError: Raises an AssertionError if both `model` and `params` are None
:return: An object of :class:`torch.optim.Optimizer`
:rtype: :class:`torch.optim.Optimizer`
"""
assert model is not None or params is not None, 'arguments model and params can not both be None'
if need_module:
config['module'] = model
elif model is not None:
config['params'] = model.parameters()
elif params is not None:
config['params'] = params
return build_from_registry(config, OPTIMIZERS)
def build_gradient_handler(config, model, optimizer):
"""Returns a gradient handler object of :class:`BaseGradientHandler` constructed from `config`,
`model` and `optimizer`.
:param config: A python dict or a :class:`colossalai.context.Config` object
containing information used in the construction of the return object
:type config: dict or :class:`colossalai.context.Config`
:param model: A model containing parameters for the gradient handler
:type model: :class:`nn.Module`
:param optimizer: An optimizer object containing parameters for the gradient handler
:type optimizer: :class:`torch.optim.Optimizer`
:return: An object of :class:`BaseGradientHandler`
:rtype: :class:`BaseGradientHandler`
"""
config_ = config.copy()
mod_type = config_.pop('type')
return GRADIENT_HANDLER.get_module(mod_type)(model, optimizer, **config_)
def build_hooks(config, trainer):
"""Returns a hook object of :class:`BaseHook` constructed from `config` and `trainer`.
:param config: A python dict or a :class:`colossalai.context.Config` object
containing information used in the construction of the return object
:type config: dict or :class:`colossalai.context.Config`
:param trainer: A :class:`Trainer` object containing parameters for the hook
:type trainer: :class:`Trainer`
:return: An object of :class:`BaseHook`
:rtype: :class:`BaseHook`
"""
config['trainer'] = trainer
return build_from_registry(config, HOOKS)
def build_transform(config):
"""Returns a transformation object of :class:`torchvision.transforms` constructed
from `config`.
:param config: A python dict or a :class:`colossalai.context.Config` object
containing information used in the construction of the return object
:type config: dict or :class:`colossalai.context.Config`
:return: An object of :class:`torchvision.transforms`
:rtype: :class:`torchvision.transforms`
"""
return build_from_registry(config, TRANSFORMS)
def build_pipe_alloc_policy(config):
"""Returns a pipeline allocation policy object constructed from `config`.
:param config: A python dict or a :class:`colossalai.context.Config` object
containing information used in the construction of the return object
:type config: dict or :class:`colossalai.context.Config`
:return: A pipeline allocation policy object
:rtype:
"""
return build_from_registry(config, PIPE_ALLOC_POLICY)
def build_data_sampler(config, dataset):
"""Returns a data sampler object of :class:`colossalai.nn.data.sampler.BaseSampler`
constructed from `config`.
:param config: A python dict or a :class:`colossalai.context.Config` object
containing information used in the construction of the return object
:type config: dict or :class:`colossalai.context.Config`
:param dataset: An object of :class:`torch.utils.data.Dataset` containing information
used in the construction of the return object
:type dataset: :class:`torch.utils.data.Dataset`
:return: An object of :class:`colossalai.nn.data.sampler.BaseSampler`
:rtype: :class:`colossalai.nn.data.sampler.BaseSampler`
"""
config_ = config.copy()
mod_type = config_.pop('type')
return SAMPLERS.get_module(mod_type)(dataset, **config_)
def build_optimizer_wrapper(config, optimizer, model=None):
"""Returns an optimizer wrapper object of :class:`torch.optim.Optimizer` constructed
from `config`, `model` and `optimizer`.
:param config: A python dict or a :class:`colossalai.context.Config` object
containing information used in the construction of the return object
:type config: dict or :class:`colossalai.context.Config`
:param optimizer: An optimizer object containing parameters for the gradient handler
:type optimizer: :class:`torch.optim.Optimizer`
:param model: A model containing parameters for the gradient handler
:type model: :class:`nn.Module`, optional
:return: An object of :class:`torch.optim.Optimizer`
:rtype: :class:`torch.optim.Optimizer`
"""
config_ = config.copy()
mod_type = config_.pop('type')
# LSG: special treatment for ZeRO level 3
if mod_type == 'ZeroRedundancyOptimizer_Level_3':
return OPTIMIZER_WRAPPERS.get_module(mod_type)(model, optimizer, **config_)
else:
return OPTIMIZER_WRAPPERS.get_module(mod_type)(optimizer, **config_)
def build_lr_scheduler(config, optimizer, total_steps, num_steps_per_epoch):
"""Returns a learning rate scheduler object of :class:`torch.optim.lr_scheduler`
constructed from `config`, `optimizer`, `total_steps` and `num_steps_per_epoch`.
:param config: A python dict or a :class:`colossalai.context.Config` object
containing information used in the construction of the return object
:type config: dict or :class:`colossalai.context.Config`
:param optimizer: An optimizer object containing parameters for the learning rate
scheduler
:type optimizer: :class:`torch.optim.Optimizer`
:param total_steps: Number of total steps of the learning rate scheduler
:type total_steps: int
:param num_steps_per_epoch: number of steps per epoch of the learning rate scheduler
:type num_steps_per_epoch: int
:return: An object of :class:`torch.optim.lr_scheduler`
:rtype: :class:`torch.optim.lr_scheduler`
"""
config_ = config.copy()
mod_type = config_.pop('type')
# warmup epochs will overwrite warmup steps
if 'warmup_epochs' in config_:
warmup_epochs = config_.pop('warmup_epochs')
config_['warmup_steps'] = int(num_steps_per_epoch * warmup_epochs)
return LR_SCHEDULERS.get_module(mod_type)(optimizer, total_steps, num_steps_per_epoch=num_steps_per_epoch,
**config_)
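
All of the builders above follow the same pattern: `build_from_config` instantiates a class directly from keyword arguments, while `build_from_registry` first resolves a `type` key in a registry. A minimal sketch of the former, using `torch.nn.Linear` as a stand-in module (not part of this commit):

```python
import torch.nn as nn
from colossalai.builder import build_from_config

# build_from_config asserts that the first argument is a class and
# instantiates it with the config dict as keyword arguments.
layer = build_from_config(nn.Linear, dict(in_features=16, out_features=8))
print(layer)  # Linear(in_features=16, out_features=8, bias=True)
```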

226
colossalai/builder/pipeline.py

@@ -0,0 +1,226 @@
import copy
import heapq
from colossalai.builder import build_model, build_layer
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger
from colossalai.utils import set_to_cuda
def _binary_partition(weights, st, ed):
"""Returns the binary partition position of `weights`, given the start
position `st` and the end position `ed`.
:param weights: A python list to be binary partitioned
:type weights: list
:param st: the start position of the binary partition
:type st: int
:param ed: the end position of the binary partition
:type ed: int
:return: the binary partition position of `weights`
:rtype: int
"""
w_sum = weights[ed - 1]
prefix = 0
if st > 0:
w_sum -= weights[st - 1]
prefix = weights[st - 1]
minimum = float("inf")
for idx in range(st + 1, ed):
front = weights[idx - 1] - prefix
diff = abs(w_sum - 2 * front)
if diff < minimum:
pos = idx
minimum = diff
return st, pos, ed
def _heap_addition(weights, intervals, add_cnt):
"""
"""
def _heap_push(heap, st, ed):
value = weights[ed - 1]
if st > 0:
value -= weights[st - 1]
heapq.heappush(heap, (-value, st, ed))
ret_intervals = []
heap = []
for st, ed in intervals:
_heap_push(heap, st, ed)
while add_cnt > 0:
_, st, ed = heapq.heappop(heap)
if ed - st == 1:
ret_intervals.append((st, ed))
else:
l, m, r = _binary_partition(weights, st, ed)
_heap_push(heap, l, m)
_heap_push(heap, m, r)
add_cnt -= 1
while heap:
_, st, ed = heapq.heappop(heap)
ret_intervals.append((st, ed))
ret_intervals.sort()
return ret_intervals
def _calc_partitions(weights, value):
prev = 0
prefix = 0
num_block = 0
intervals = []
for idx, w in enumerate(weights):
if weights[idx] - prefix > value:
intervals.append((prev, idx))
prev = idx
prefix = weights[idx - 1]
num_block += 1
intervals.append((prev, len(weights)))
return num_block + 1, intervals
def _binary_search(weights, num):
length = len(weights)
prefix = [1 if w == 0 else w for w in weights]
for i in range(1, length):
prefix[i] += prefix[i - 1]
lower_bound = max(weights)
upper_bound = prefix[length - 1]
while upper_bound > lower_bound:
mid = (upper_bound + lower_bound) // 2
number, _ = _calc_partitions(prefix, mid)
if number <= num:
upper_bound = mid
else:
lower_bound = mid + 1
num_block, intervals = _calc_partitions(prefix, upper_bound)
if num_block < num:
intervals = _heap_addition(prefix, intervals, num - num_block)
return intervals
def _partition_uniform(num_items, num_parts, num_chunks):
assert num_items % num_chunks == 0, \
"Layer length should be divided by the number of chunks, otherwise parameter method is recomended"
logger = get_global_dist_logger()
parts = [[] for _ in range(num_parts)]
partition_items = num_items // num_chunks
for idx in range(num_chunks):
base_idx = idx * partition_items
chunk_size = partition_items // num_parts
left = num_parts - partition_items % num_parts
if chunk_size == 0:
logger.warning("Some nodes in Pipeline have no requests")
for p in range(num_parts):
st = base_idx
base_idx += chunk_size + (p >= left)
parts[p].append((st, base_idx))
return parts
def _partition_balanced(weights, num_parts, num_chunks):
num_total = num_parts * num_chunks
num_items = len(weights)
if num_items <= num_total:
return _partition_uniform(num_items, num_parts, num_chunks)
intervals = _binary_search(weights, num_total)
current = 0
parts = [[] for _ in range(num_parts)]
for inter in intervals:
parts[current].append(inter)
current = (current + 1) % num_parts
return parts
class ModelInitializer():
def __init__(self, config, num_chunks, verbose=False):
self.num_chunks = num_chunks
self.ori_model = build_model(config)
self.layers = self.ori_model.layers_cfg
layer_length = len(self.layers)
self.verbose = verbose
self._logger = get_global_dist_logger()
self._logger.info(f"The total length of layers is {layer_length}", ranks=[0])
def model_initialize(self, partition_method='parameter'):
# Some space for initializing communication groups
self._interval = None
self._partition_layers(method=partition_method)
models = self._build()
model = set_to_cuda(models)
return model
def _partition_layers(self, method):
pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
method = method.lower()
# Make a partition
if method == 'layer':
num_layers = len(self.layers)
self.parts = _partition_uniform(num_layers, pipeline_parallel_size, self.num_chunks)
elif method == 'parameter':
param_counts = self._count_layer_params()
# print_rank_0(param_counts)
self.parts = _partition_balanced(param_counts, pipeline_parallel_size, self.num_chunks)
else:
assert method == 'layer', "Method should be a pre-set string"
# Display the partition
if gpc.get_global_rank() == 0 and self.verbose:
log_str = 'Layer allocation after partitioning: \n'
for stage in range(pipeline_parallel_size):
num_layers = 0
for st, ed in self.parts[stage]:
num_layers += ed - st
log_str += f'\n===== stage={stage}, layers={num_layers} =====\n'
for st, ed in self.parts[stage]:
for idx, layer in enumerate(self.layers[st: ed]):
log_str += f'\t{idx + st:2d}: {layer}\n'
self._logger.info(log_str)
# Save the partition
self._interval = self.parts[pipeline_rank]
def _build(self):
"""Build model from the layer cfg according to the partition
"""
models = []
for st, ed in self._interval:
model = copy.copy(self.ori_model)
model.build_from_cfg(st, ed)
models.append(model)
return models
def _count_layer_params(self):
"""Count the number of parameters in each layer
"""
param_counts = [0] * len(self.layers)
for idx, cfg in enumerate(self.layers):
layer = build_layer(cfg)
params = filter(lambda p: p.requires_grad, layer.parameters())
param_counts[idx] = sum(p.numel() for p in params)
return param_counts
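
The partition helpers above decide which layer indices each pipeline stage owns. A toy illustration of the balanced strategy, assuming the package is importable and using made-up per-layer parameter counts:

```python
from colossalai.builder.pipeline import _partition_balanced

# Hypothetical per-layer parameter counts for a 6-layer model.
param_counts = [4, 1, 1, 1, 1, 4]

# Two pipeline stages, one chunk each: every stage receives a list of
# (start, end) layer intervals with roughly equal total weight.
parts = _partition_balanced(param_counts, num_parts=2, num_chunks=1)
print(parts)  # [[(0, 3)], [(3, 6)]] -> 6 units of weight per stage
```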

215
colossalai/checkpointing.py

@@ -0,0 +1,215 @@
import os
import os.path as osp
import re
from typing import Tuple
import torch
from .context import Config
from .context.parallel_mode import ParallelMode
from .core import global_context as gpc
__all__ = [
'get_checkpoint_path',
'get_latest_checkpoint_path',
'get_latest_checkpoint_pattern',
'save_checkpoint',
'load_checkpoint'
]
def unwrap_config(config: Config):
'''
unwrap Config objects to normal dicts
'''
config_dict = dict()
for k, v in config.items():
if isinstance(v, dict):
config_dict[k] = unwrap_config(v)
else:
config_dict[k] = v
return config_dict
def _get_ranks_name():
# tensor parallel
tp_local_rank = 0
if gpc.is_initialized(ParallelMode.TENSOR):
tp_local_rank = gpc.get_local_rank(ParallelMode.TENSOR)
# pipeline parallel
pp_local_rank = 0
if gpc.is_initialized(ParallelMode.PIPELINE):
pp_local_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
ranks_name = f'tp{tp_local_rank}-pp{pp_local_rank}'
return ranks_name
def _get_standard_checkpoint_filename(epoch: int, suffix: str = ''):
ranks_name = _get_ranks_name()
return f'epoch{epoch}-{ranks_name}{suffix}.pt'
def get_checkpoint_path(checkpoint_dir: str, epoch: int, suffix: str = ''):
'''This is a function to generate the checkpoint path from the (checkpoint_dir, epoch, suffix, gpu_parallel_rank) tuple.
This is useful when saving and later retrieving a checkpoint.
:param checkpoint_dir: set up a directory for saving checkpoints
:type checkpoint_dir: str
:param epoch: epoch number (indicates how many epochs this model has been trained for)
:type epoch: int
:param suffix: additional notation to specify the model or checkpoint, defaults to ''
:type suffix: str, optional
:return: checkpoint path to be generated
:rtype: path
'''
ckpt_filename = _get_standard_checkpoint_filename(epoch, suffix)
return os.path.join(checkpoint_dir, ckpt_filename)
def _ensure_directory_exists(filename: str):
# ensure the directory exists
dir = os.path.dirname(filename)
if not os.path.exists(dir):
os.makedirs(dir)
def get_latest_checkpoint_pattern(suffix: str = ''):
'''Generates the regular expression that matches the latest checkpoint's filename pattern
:param suffix: additional notation to specify the model or checkpoint, defaults to ''
:type suffix: str, optional
:return: checkpoint pattern
:rtype: regular expression
'''
ranks_name = _get_ranks_name()
ckpt_pattern = re.compile(rf'epoch(\d+)-{ranks_name}{suffix}\.pt')
return ckpt_pattern
def get_latest_checkpoint_path(checkpoint_dir: str, suffix: str = ''):
'''This is a function to retrieve the latest checkpoint path from the (checkpoint_dir, suffix, gpu_parallel_rank) tuple.
This is useful when restoring a checkpoint, especially when you do not know the epoch number.
:param checkpoint_dir: directory for saving checkpoints
:type checkpoint_dir: str
:param suffix: additional notation to specify the model or checkpoint, defaults to ''
:type suffix: str, optional
:raises FileNotFoundError: raise error when we cannot find the latest checkpoint file with inputs given
:return: the latest checkpoint path to be retrieved
:rtype: path
'''
CKPT_NAME_PAT = get_latest_checkpoint_pattern(suffix=suffix)
last_epoch = -1
assert osp.isdir(checkpoint_dir), f'{checkpoint_dir} is not a directory'
for filename in os.listdir(checkpoint_dir):
ret = CKPT_NAME_PAT.match(filename)
if ret:
epoch = int(ret[0].split('-')[0].lstrip('epoch'))
if epoch > last_epoch:
last_epoch = epoch
if last_epoch == -1:
ranks_name = _get_ranks_name()
raise FileNotFoundError(f"Cannot find the latest checkpoint file for {ranks_name} in {checkpoint_dir}")
else:
target_file = _get_standard_checkpoint_filename(last_epoch, suffix=suffix)
path = osp.join(checkpoint_dir, target_file)
return path
def save_checkpoint(checkpoint_path: str,
epoch: int,
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
**kwargs):
'''Given a checkpoint path, saves all the training components' parameters or buffers, such as the model, optimizer and lr_scheduler, into a checkpoint dictionary.
This method can be used for both the colossalai nn.BaseModel and a normal PyTorch nn.Module.
:param checkpoint_path: set up a directory for saving checkpoints
:type checkpoint_path: str
:param epoch: epoch number (indicates how many epochs this model has been trained for)
:type epoch: int
:param model: model to be registered
:type model: torch.nn.Module
:param optimizer: optimizer to be registered
:type optimizer: torch.optim.Optimizer
:param lr_scheduler: lr_scheduler to be registered, defaults to None
:type lr_scheduler: torch.optim.lr_scheduler._LRScheduler, optional
'''
# for compatibility with normal pytorch nn.Module
if hasattr(model, 'state_dict_for_save_checkpoint'):
model_sd = model.state_dict_for_save_checkpoint()
else:
model_sd = model.state_dict()
# ckpt container
checkpoint = {
'epoch': epoch,
'model': model_sd,
'optimizer': optimizer.state_dict(),
**kwargs
}
if lr_scheduler is not None:
checkpoint['lr_scheduler'] = lr_scheduler.state_dict()
_ensure_directory_exists(checkpoint_path)
torch.save(checkpoint, checkpoint_path)
def load_checkpoint(checkpoint_path: str,
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
finetune: bool = False,
strict: bool = True) -> Tuple:
'''Loads the checkpoint file.
If finetune is False, we intend to continue/resume training from the given checkpoint,
so parameters and buffers are copied from the state_dict into these modules (model, optimizer, lr_scheduler) and their descendants.
If finetune is True, only the weights and buffers of the model are reloaded.
If strict is True, the keys of the state_dict must exactly match the keys returned by the module's state_dict() function.
:param checkpoint_path: the exact and matched checkpoint_path directory to retrieve appropriate state_dict
:type checkpoint_path: str
:param model: model to reload parameters and buffers
:type model: torch.nn.Module
:param optimizer: optimizer to restore
:type optimizer: torch.optim.Optimizer
:param lr_scheduler: lr_scheduler to restore, defaults to None
:type lr_scheduler: torch.optim.lr_scheduler._LRScheduler, optional
:param finetune: whether to finetune the model with new dataset or continue the pre-training, defaults to False
:type finetune: bool, optional
:param strict: whether to strictly enforce that the keys in
:attr:`state_dict` of the checkpoint match the names of
parameters and buffers in model, defaults to True
:type strict: bool, optional
:raises ValueError: raise error if the model/optimizer cannot be successfully restored
:return: (the epoch number of the checkpoint retrieved, the checkpoint retrieved)
:rtype: Tuple
'''
# Load the checkpoint.
checkpoint = torch.load(checkpoint_path, map_location='cpu')
try:
last_epoch = checkpoint.pop('epoch') if not finetune else 0
model.load_state_dict(checkpoint.pop('model'), strict=strict)
except KeyError:
raise ValueError('Checkpoint is corrupted')
if not finetune:
try:
optimizer.load_state_dict(checkpoint.pop('optimizer'))
except KeyError:
raise ValueError('Checkpoint is corrupted')
if lr_scheduler is not None and 'lr_scheduler' in checkpoint:
lr_scheduler.load_state_dict(checkpoint.pop('lr_scheduler'))
return last_epoch, checkpoint
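
A sketch of how the helpers above fit together for a save/resume round trip. The model, optimizer and the './ckpt' directory are placeholders, and because the filename encodes the tensor/pipeline ranks, it assumes the parallel context has already been set up (e.g. via colossalai.initialize):

```python
import torch
from colossalai.checkpointing import (get_checkpoint_path, get_latest_checkpoint_path,
                                      save_checkpoint, load_checkpoint)

model = torch.nn.Linear(4, 4)                        # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Save epoch 3 under ./ckpt; the filename encodes the epoch and the tp/pp ranks.
path = get_checkpoint_path('./ckpt', epoch=3)
save_checkpoint(path, epoch=3, model=model, optimizer=optimizer)

# Later: find the newest checkpoint for this rank and resume from it.
latest = get_latest_checkpoint_path('./ckpt')
last_epoch, extras = load_checkpoint(latest, model, optimizer)
```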

14
colossalai/communication/__init__.py

@@ -0,0 +1,14 @@
from .collective import all_gather, reduce_scatter, scatter
from .p2p import (send_forward, send_forward_recv_forward, send_backward_recv_forward,
send_backward, send_backward_recv_backward, send_forward_recv_backward,
send_forward_backward_recv_forward_backward, recv_forward, recv_backward)
from .ring import ring_forward
from .utils import send_tensor_meta, recv_tensor_meta
__all__ = [
'all_gather', 'reduce_scatter', 'scatter',
'send_forward', 'send_forward_recv_forward', 'send_forward_backward_recv_forward_backward',
'send_backward', 'send_backward_recv_backward', 'send_backward_recv_forward',
'send_forward_recv_backward', 'recv_backward', 'recv_forward',
'ring_forward', 'send_tensor_meta', 'recv_tensor_meta'
]

84
colossalai/communication/collective.py

@@ -0,0 +1,84 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
import torch.distributed as dist
from torch import Tensor
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device
def all_gather(tensor: Tensor, dim: int,
parallel_mode: ParallelMode) -> Tensor:
"""Gathers all tensors from the parallel group and concatenates them in a
specific dimension.
:param tensor: Tensor to be gathered
:param dim: The dimension to concatenate along
:param parallel_mode: Parallel group mode used in this communication
:type tensor: Tensor
:type dim: int
:type parallel_mode: ParallelMode
:return: The tensor generated by all-gather
:rtype: Tensor
"""
depth = gpc.get_world_size(parallel_mode)
temp = tensor.clone()
shape = list(temp.shape)
shape[dim] *= depth
out = torch.empty(shape, dtype=temp.dtype, device=get_current_device())
out = list(torch.chunk(out, depth, dim=dim))
out = [val.contiguous() for val in out]
dist.all_gather(out, temp, group=gpc.get_group(parallel_mode))
out = torch.cat(out, dim=dim)
return out
def reduce_scatter(tensor: Tensor, dim: int,
parallel_mode: ParallelMode) -> Tensor:
"""Reduces all tensors then scatters it in a specific dimension to all
members in the parallel group.
:param tensor: Tensor to be reduced and scattered
:param dim: The dimension to scatter along
:param parallel_mode: Parallel group mode used in this communication
:type tensor: Tensor
:type dim: int
:type parallel_mode: ParallelMode
:return: The tensor generated by reduce-scatter
:rtype: Tensor
"""
depth = gpc.get_world_size(parallel_mode)
temp = list(torch.chunk(tensor, depth, dim=dim))
temp = [val.contiguous() for val in temp]
out = torch.empty(temp[0].shape,
dtype=temp[0].dtype,
device=get_current_device())
dist.reduce_scatter(output=out,
input_list=temp,
group=gpc.get_group(parallel_mode))
return out
def scatter(tensor: Tensor, src: int, dim: int,
parallel_mode: ParallelMode) -> Tensor:
"""Scatters in a specific dimension from source rank to all ranks in
the parallel group.
:param tensor: Tensor to be scattered
:param src: The source rank of the scatter
:param dim: The dimension to scatter along
:param parallel_mode: Parallel group mode used in this communication
:type tensor: Tensor
:type src: int
:type dim: int
:type parallel_mode: ParallelMode
:return: The tensor generated by scatter
:rtype: Tensor
"""
depth = gpc.get_world_size(parallel_mode)
temp = tensor.clone()
dist.broadcast(temp, src=src, group=gpc.get_group(parallel_mode))
rank = gpc.get_local_rank(parallel_mode)
out = torch.chunk(temp, depth, dim=dim)[rank].contiguous()
return out
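
A hedged sketch of `all_gather` in the data-parallel group. It assumes the processes were launched with a distributed launcher and the parallel context has been initialized (e.g. via colossalai.initialize), so every rank runs this same snippet:

```python
import torch
from colossalai.communication import all_gather
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device

# Each rank contributes one row filled with its own local rank.
rank = gpc.get_local_rank(ParallelMode.DATA)
x = torch.full((1, 4), float(rank), device=get_current_device())

# Rows are gathered along dim 0, so the result has shape (world_size, 4).
gathered = all_gather(x, dim=0, parallel_mode=ParallelMode.DATA)
```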

333
colossalai/communication/p2p.py

@@ -0,0 +1,333 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
import torch.distributed as dist
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device
def _communicate(tensor_send_next=None,
tensor_send_prev=None,
recv_prev=False,
recv_next=False,
recv_prev_shape=None,
recv_next_shape=None,
prev_rank=None,
next_rank=None,
up_group=None,
down_group=None,
dtype=None):
"""
Adapted from megatron.p2p_communication.
Communicate tensors between stages. Used as helper method in other
communication methods that are used in pipeline schedule.
Takes the following arguments:
tensor_send_next: tensor to send to next rank (no tensor sent if
set to None).
tensor_send_prev: tensor to send to prev rank (no tensor sent if
set to None).
recv_prev: boolean for whether tensor should be received from
previous rank.
recv_next: boolean for whether tensor should be received from
next rank.
Returns:
(tensor_recv_prev, tensor_recv_next)
"""
# Create placeholder tensors for receive in forward and backward directions
# if needed.
tensor_recv_prev = None
tensor_recv_next = None
if recv_prev:
assert recv_prev_shape is not None
tensor_recv_prev = torch.empty(recv_prev_shape,
requires_grad=True,
device=get_current_device(),
dtype=dtype)
if recv_next:
assert recv_next_shape is not None
tensor_recv_next = torch.empty(recv_next_shape,
requires_grad=True,
device=get_current_device(),
dtype=dtype)
if tensor_send_prev is not None or recv_prev:
if prev_rank is None:
prev_rank = gpc.get_prev_global_rank(
ParallelMode.PIPELINE)
if up_group is None:
up_group = gpc.get_group(ParallelMode.PIPELINE_PREV)
if tensor_send_next is not None or recv_next:
if next_rank is None:
next_rank = gpc.get_next_global_rank(
ParallelMode.PIPELINE)
if down_group is None:
down_group = gpc.get_group(ParallelMode.PIPELINE_NEXT)
# rank = dist.get_rank()
rank = gpc.get_global_rank()
ops = []
if tensor_send_prev is not None:
send_prev_op = dist.broadcast(tensor_send_prev,
src=rank,
group=up_group,
async_op=True)
ops.append(send_prev_op)
if tensor_recv_prev is not None:
recv_prev_op = dist.broadcast(tensor_recv_prev,
src=prev_rank,
group=up_group,
async_op=True)
ops.append(recv_prev_op)
if tensor_recv_next is not None:
recv_next_op = dist.broadcast(tensor_recv_next,
src=next_rank,
group=down_group,
async_op=True)
ops.append(recv_next_op)
if tensor_send_next is not None:
send_next_op = dist.broadcast(tensor_send_next,
src=rank,
group=down_group,
async_op=True)
ops.append(send_next_op)
for req in ops:
req.wait()
# To protect against race condition when using batch_isend_irecv().
torch.cuda.synchronize()
return tensor_recv_prev, tensor_recv_next
def recv_forward(input_tensor_shape, prev_rank=None, up_group=None):
"""Receives the input tensor from the previous member in pipeline.
:param input_tensor_shape: The shape of the tensor to be received
:param prev_rank: The rank of the source of the tensor
:param up_group: Communication group including the previous member in pipeline parallel group
:type input_tensor_shape: torch.Size
:type prev_rank: int, optional
:type up_group: ProcessGroup, optional
:return: The input tensor in forward step
:rtype: Tensor
"""
if gpc.is_first_rank(ParallelMode.PIPELINE):
input_tensor = None
else:
input_tensor, _ = _communicate(recv_prev=True,
recv_prev_shape=input_tensor_shape,
prev_rank=prev_rank,
up_group=up_group)
return input_tensor
def recv_backward(output_grad_shape, next_rank=None, down_group=None):
"""Receives the grad tensor from the next member in pipeline.
:param output_grad_shape: The shape of the tensor to be received
:param next_rank: The rank of the source of the tensor
:param down_group: Communication group including the next member in pipeline parallel group
:type output_grad_shape: torch.Size
:type next_rank: int, optional
:type down_group: ProcessGroup, optional
:return: The grad of output tensor in forward step
:rtype: Tensor
"""
if gpc.is_last_rank(ParallelMode.PIPELINE):
output_tensor_grad = None
else:
_, output_tensor_grad = _communicate(recv_next=True,
recv_next_shape=output_grad_shape,
next_rank=next_rank,
down_group=down_group)
return output_tensor_grad
def send_forward(output_tensor,
next_rank=None,
down_group=None):
"""Sends the input tensor to the next member in pipeline.
:param output_tensor: Tensor to be sent
:param next_rank: The rank of the recipient of the tensor
:param down_group: Communication group including the next member in pipeline parallel group
:type output_tensor: Tensor
:type next_rank: int, optional
:type down_group: ProcessGroup, optional
"""
if not gpc.is_last_rank(ParallelMode.PIPELINE):
_communicate(tensor_send_next=output_tensor,
next_rank=next_rank,
down_group=down_group)
def send_backward(input_tensor_grad,
prev_rank=None,
up_group=None):
"""Sends the grad tensor to the previous member in pipeline.
:param input_tensor_grad: Tensor to be sent
:param prev_rank: The rank of the recipient of the tensor
:param up_group: Communication group including the previous member in pipeline parallel group
:type input_tensor_grad: Tensor
:type prev_rank: int, optional
:type up_group: ProcessGroup, optional
"""
if not gpc.is_first_rank(ParallelMode.PIPELINE):
_communicate(tensor_send_prev=input_tensor_grad,
prev_rank=prev_rank,
up_group=up_group)
def send_forward_recv_backward(output_tensor,
output_grad_shape,
recv_next=True,
next_rank=None,
down_group=None):
"""Batched communication operation. Sends the input tensor to the
next member in pipeline, while recieves the grad tensor from the
next member in pipeline.
:param output_tensor: Tensor to be sent
:param output_grad_shape: The shape of the tensor to be received
:type output_tensor: Tensor
:type output_grad_shape: torch.Size
:return: The grad of output tensor in forward step
:rtype: Tensor
"""
if gpc.is_last_rank(ParallelMode.PIPELINE):
output_tensor_grad = None
else:
_, output_tensor_grad = _communicate(tensor_send_next=output_tensor,
recv_next=recv_next,
recv_next_shape=output_grad_shape,
next_rank=next_rank,
down_group=down_group)
return output_tensor_grad
def send_backward_recv_forward(input_tensor_grad,
input_tensor_shape,
recv_prev=True,
prev_rank=None,
up_group=None):
"""Batched communication operation. Sends the grad tensor to the
previous member in pipeline, while recieves the input tensor from the
previous member in pipeline.
:param input_tensor_grad: Tensor to be sent
:param input_tensor_shape: The shape of the tensor to be received
:type input_tensor_grad: Tensor
:type input_tensor_shape: torch.Size
:return: The input tensor in forward step
:rtype: Tensor
"""
if gpc.is_first_rank(ParallelMode.PIPELINE):
input_tensor = None
else:
input_tensor, _ = _communicate(tensor_send_prev=input_tensor_grad,
recv_prev=recv_prev,
recv_prev_shape=input_tensor_shape,
prev_rank=prev_rank,
up_group=up_group)
return input_tensor
def send_forward_recv_forward(output_tensor,
input_tensor_shape,
recv_prev=True,
prev_rank=None,
next_rank=None,
up_group=None,
down_group=None):
"""Batched communication operation. Sends the input tensor to the
next member in pipeline, while recieves the input tensor from the
previous member in pipeline.
:param output_tensor: Tensor to be sent
:param input_tensor_shape: The shape of the tensor to be received
:type output_tensor: Tensor
:type input_tensor_shape: torch.Size
:return: The input tensor in forward step
:rtype: Tensor
"""
input_tensor, _ = _communicate(tensor_send_next=output_tensor,
recv_prev=recv_prev,
recv_prev_shape=input_tensor_shape,
prev_rank=prev_rank,
next_rank=next_rank,
up_group=up_group,
down_group=down_group)
return input_tensor
def send_backward_recv_backward(input_tensor_grad,
output_grad_shape,
recv_next=True,
prev_rank=None,
next_rank=None,
up_group=None,
down_group=None):
"""Batched communication operation. Sends the grad tensor to the
previous member in pipeline, while recieves the grad tensor from the
next member in pipeline.
:param input_tensor_grad: Tensor to be sent
:param output_grad_shape: The shape of the tensor to be received
:type input_tensor_grad: Tensor
:type output_grad_shape: torch.Size
:return: The grad of output tensor in forward step
:rtype: Tensor
"""
_, output_tensor_grad = _communicate(tensor_send_prev=input_tensor_grad,
recv_next=recv_next,
recv_next_shape=output_grad_shape,
prev_rank=prev_rank,
next_rank=next_rank,
up_group=up_group,
down_group=down_group)
return output_tensor_grad
def send_forward_backward_recv_forward_backward(output_tensor,
input_tensor_grad,
input_tensor_shape,
output_grad_shape,
recv_prev=True,
recv_next=True,
prev_rank=None,
next_rank=None,
up_group=None,
down_group=None):
"""Batched communication operation. Sends the input tensor to the next and
the grad tensor to the previous, while recieves the grad tensor from the
next and the input tensor from the previous.
:param output_tensor: Tensor sent to the next
:param input_tensor_grad: Tensor sent to the previous
:param input_tensor_shape: The shape of the tensor received from the previous
:param output_grad_shape: The shape of the tensor received from the next
:type output_tensor: Tensor
:type input_tensor_grad: Tensor
:type input_tensor_shape: torch.Size
:type output_grad_shape: torch.Size
:return: (the input tensor in forward step, the grad of output tensor in forward step)
:rtype: (Tensor, Tensor)
"""
input_tensor, output_tensor_grad = _communicate(
tensor_send_next=output_tensor,
tensor_send_prev=input_tensor_grad,
recv_prev=recv_prev,
recv_next=recv_next,
recv_prev_shape=input_tensor_shape,
recv_next_shape=output_grad_shape,
prev_rank=prev_rank,
next_rank=next_rank,
up_group=up_group,
down_group=down_group)
return input_tensor, output_tensor_grad
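
A hedged sketch of a single forward hand-off with the helpers above, written for a two-stage pipeline. It assumes the pipeline process groups were created by colossalai.initialize, and the activation shape below is made up:

```python
import torch
from colossalai.communication import (send_tensor_meta, recv_tensor_meta,
                                      send_forward, recv_forward)
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device

if gpc.is_first_rank(ParallelMode.PIPELINE):
    # Stage 0: announce the shape first, then send the activation itself.
    output = torch.randn(8, 16, device=get_current_device())
    send_tensor_meta(output)
    send_forward(output)
else:
    # Stage 1: recover the shape, then receive the activation from stage 0.
    shape = recv_tensor_meta(None)
    activation = recv_forward(shape)
```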

54
colossalai/communication/ring.py

@@ -0,0 +1,54 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device, synchronize
def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode):
"""Sends a tensor to the next member and recieves a tensor from the previous member.
This function returns the recieved tensor from the previous member.
:param tensor_send_next: Tensor sent to next member
:param parallel_mode: Parallel group mode used in this communication
:type tensor_send_next: Tensor
:type parallel_mode: ParallelMode
:return: The tensor received from the previous member
:rtype: Tensor
"""
buffer_shape = tensor_send_next.size()
ops = []
current_rank = gpc.get_global_rank()
tensor_recv_prev = torch.empty(buffer_shape,
requires_grad=True,
device=get_current_device(),
dtype=tensor_send_next.dtype)
# send to next rank
send_next_op = torch.distributed.P2POp(
torch.distributed.isend, tensor_send_next,
gpc.get_next_global_rank(parallel_mode))
ops.append(send_next_op)
# receive from prev rank
recv_prev_op = torch.distributed.P2POp(
torch.distributed.irecv, tensor_recv_prev,
gpc.get_prev_global_rank(parallel_mode))
ops.append(recv_prev_op)
if current_rank % 2 == 0:
ops = ops[::-1]
reqs = torch.distributed.batch_isend_irecv(ops)
for req in reqs:
req.wait()
# To protect against race condition when using batch_isend_irecv().
synchronize()
return tensor_recv_prev

73
colossalai/communication/utils.py

@@ -0,0 +1,73 @@
import torch
import torch.distributed as dist
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device
def send_tensor_meta(tensor, need_meta=True, down_group=None):
"""Sends tensor meta information before sending a specific tensor.
Since the recipient must know the shape of the tensor in p2p communications,
meta information of the tensor should be sent before communications. This function
synchronizes with :func:`recv_tensor_meta`.
:param tensor: Tensor to be sent
:param need_meta: If False, meta information won't be sent
:param down_group: Communication group including the next member in pipeline parallel group
:type tensor: Tensor
:type need_meta: bool, optional
:type down_group: ProcessGroup, optional
:return: False
:rtype: bool
"""
if need_meta:
rank = gpc.get_global_rank()
if down_group is None:
down_group = gpc.get_group(ParallelMode.PIPELINE_NEXT)
tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()}
send_shape = torch.tensor(tensor.size(), **tensor_kwargs)
send_ndims = torch.tensor(len(tensor.size()), **tensor_kwargs)
dist.broadcast(send_ndims, src=rank, group=down_group)
dist.broadcast(send_shape, src=rank, group=down_group)
return False
def recv_tensor_meta(tensor_shape, prev_rank=None, up_group=None):
"""Recieves tensor meta information before recieving a specific tensor.
Since the recipient must know the shape of the tensor in p2p communications,
meta information of the tensor should be recieved before communications. This function
synchronizes with :func:`send_tensor_meta`.
:param tensor_shape: The shape of the tensor to be recieved
:param prev_rank: The rank of the source of the tensor
:param up_group: Communication group including the previous member in pipeline parallel group
:type tensor_shape: torch.Size
:type prev_rank: int, optional
:type up_group: ProcessGroup, optional
:return: The shape of the tensor to be received
:rtype: torch.Size
"""
if tensor_shape is None:
if prev_rank is None:
prev_rank = gpc.get_prev_global_rank(
ParallelMode.PIPELINE)
if up_group is None:
up_group = gpc.get_group(ParallelMode.PIPELINE_PREV)
tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()}
recv_ndims = torch.empty((), **tensor_kwargs)
dist.broadcast(recv_ndims, src=prev_rank, group=up_group)
recv_shape = torch.empty(recv_ndims, **tensor_kwargs)
dist.broadcast(recv_shape, src=prev_rank, group=up_group)
tensor_shape = torch.Size(recv_shape)
return tensor_shape

31
colossalai/constants.py

@@ -0,0 +1,31 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
ALLOWED_MODES = [None, '1d', '2d', '2.5d', '3d', 'sequence']
# initializer
INITIALIZER_MAPPING = {
'data': 'Initializer_Data',
'tensor': 'Initializer_Tensor',
'pipeline': 'Initializer_Pipeline',
'embedding': 'Initializer_Embedding',
'1d': 'Initializer_1D',
'2d': 'Initializer_2D',
'2.5d': 'Initializer_2p5D',
'3d': 'Initializer_3D',
'sequence': 'Initializer_Sequence'
}
# 2D parallel
SUMMA_DIM = 'SUMMA_DIM'
# 2.5D parallel
TESSERACT_DIM = 'TESSERACT_DIM'
TESSERACT_DEP = 'TESSERACT_DEP'
# 3D parallel
DEPTH_3D = 'DEPTH_3D'
# Tensor parallel attributes
IS_TENSOR_PARALLEL = 'is_tensor_parallel'
TENSOR_PARALLEL_ATTRIBUTES = [IS_TENSOR_PARALLEL]

5
colossalai/context/__init__.py

@@ -0,0 +1,5 @@
from .config import Config
from .parallel_context import ParallelContext
from .parallel_context import ParallelMode
from .process_group_initializer import *
from .random import *

70
colossalai/context/_utils.py

@@ -0,0 +1,70 @@
import math
def set_parallel_size(obj, config: dict, key: str, attr_name: str):
if key in config:
ele = config[key]
if isinstance(ele, int):
setattr(obj, attr_name, ele)
elif isinstance(ele, dict):
setattr(obj, attr_name, ele['size'])
else:
raise NotImplementedError(
f"Parallel configuration does not support this kind of argument, please use int or dict"
)
def add_tensor_pg(pg_init, mode, size, depth=None):
if mode == '1d':
pg_init.append(dict(
type='Initializer1D',
parallel_size=size
))
elif mode == '2d':
dim = math.floor(math.sqrt(size))
pg_init.append(dict(
type='Initializer2D_Col',
summa_dim=dim
))
pg_init.append(dict(
type='Initializer2D_Row',
summa_dim=dim
))
elif mode == '2.5d':
dim = math.floor(math.sqrt(size // depth))
pg_init.append(dict(
type='Initializer_Tesseract_ROW',
tesseract_dim=dim,
tesseract_dep=depth
))
pg_init.append(dict(
type='Initializer_Tesseract_COL',
tesseract_dim=dim,
tesseract_dep=depth
))
pg_init.append(dict(
type='Initializer_Tesseract_DEP',
tesseract_dim=dim,
tesseract_dep=depth
))
pg_init.append(dict(
type='Initializer_Tesseract_XZ',
tesseract_dim=dim,
tesseract_dep=depth
))
elif mode == '3d':
dim = math.floor(math.pow(size, 1.0 / 3.0) + 0.5)
pg_init.append(dict(
type='ParallelInitializer3D_Input',
depth=dim
))
pg_init.append(dict(
type='ParallelInitializer3D_Weight',
depth=dim
))
pg_init.append(dict(
type='ParallelInitializer3D_Output',
depth=dim
))
else:
raise NotImplementedError("This kind of tensor splitting has not been implemented yet")

99
colossalai/context/config.py

@@ -0,0 +1,99 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import inspect
import sys
from importlib.machinery import SourceFileLoader
from pathlib import Path
class Config(dict):
"""This is a wrapper class for dict objects so that values of which can be
accessed as attributes.
:param config: The dict object to be wrapped
:type config: dict
"""
def __init__(self, config: dict = None):
if config is not None:
for k, v in config.items():
self._add_item(k, v)
def __missing__(self, key):
raise KeyError(key)
def __getattr__(self, key):
try:
value = super(Config, self).__getitem__(key)
return value
except KeyError:
raise AttributeError(key)
def __setattr__(self, key, value):
super(Config, self).__setitem__(key, value)
def _add_item(self, key, value):
if isinstance(value, dict):
self.__setattr__(key, Config(value))
else:
self.__setattr__(key, value)
def update(self, config):
assert isinstance(config, (Config, dict)), 'can only update dictionary or Config objects.'
for k, v in config.items():
self._add_item(k, v)
return self
@staticmethod
def from_file(filename: str):
"""Reads a python file and constructs a corresponding :class:`Config` object.
:param filename: Name of the file to construct the return object
:type filename: str
:raises AssertionError: Raises an AssertionError if the file does not exist, or the file
is not .py file
:return: A :class:`Config` object constructed with information in the file
:rtype: :class:`Config`
"""
# check config path
if isinstance(filename, str):
filepath = Path(filename).absolute()
elif isinstance(filename, Path):
filepath = filename.absolute()
assert filepath.exists(), f'{filename} is not found, please check your configuration path'
# check extension
extension = filepath.suffix
assert extension == '.py', 'only .py files are supported'
# import the config as module
remove_path = False
if filepath.parent not in sys.path:
sys.path.insert(0, str(filepath))
remove_path = True
module_name = filepath.stem
source_file = SourceFileLoader(fullname=str(module_name), path=str(filepath))
module = source_file.load_module()
# load into config
config = Config()
for k, v in module.__dict__.items():
if k.startswith('__') or inspect.ismodule(v) or inspect.isclass(v):
continue
else:
config._add_item(k, v)
# TODO: replace with a logger warning here when the logger is ready
print('warning: variables that start with __, or are modules or class declarations, are omitted')
# remove module
del sys.modules[module_name]
if remove_path:
sys.path.pop(0)
return config
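
A short sketch of how `Config.from_file` is used. The `train_config.py` file below is illustrative, not part of this commit; any top-level variable in the file becomes an attribute, and nested dicts are wrapped as `Config` objects:

```python
# Contents of a hypothetical ./train_config.py:
#   BATCH_SIZE = 128
#   optimizer = dict(type='Adam', lr=1e-3)
from colossalai.context import Config

config = Config.from_file('./train_config.py')
print(config.BATCH_SIZE)      # 128, via attribute access on the wrapped dict
print(config.optimizer.lr)    # 1e-3, nested dicts become Config objects too
```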

454
colossalai/context/parallel_context.py

@@ -0,0 +1,454 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import os
import random
from typing import Union
import numpy as np
import torch
import torch.distributed as dist
from colossalai.constants import ALLOWED_MODES, INITIALIZER_MAPPING
from colossalai.context.config import Config
from colossalai.registry import DIST_GROUP_INITIALIZER
from ._utils import set_parallel_size
from .parallel_mode import ParallelMode
from .random import add_seed, get_seeds, set_mode
class ParallelContext:
"""This class provides interface functions for users to get the parallel context,
such as the global rank, the local rank, the world size, etc. of each device.
:param args: The distributed arguments in the system
:type args: dict
"""
def __init__(self, args=None):
# distributed settings
self._global_ranks = dict()
self._local_ranks = dict()
self._world_sizes = dict()
self._groups = dict()
self._ranks_in_group = dict()
# load config from file
self._dist_args = args
self._config = None
# default 3D parallel args, will be overwritten during process group initialization
self.world_size = 1
self.data_parallel_size = 1
self.pipeline_parallel_size = 1
self.tensor_parallel_size = 1
@property
def config(self):
return self._config
def load_config(self, config: Union[dict, str]):
"""Loads the configuration from either a dict or a file.
:param config: Either a dict containing the configuration information or the filename
of a file containing the configuration information
:type config: dict or str
:raises TypeError: Raises a TypeError if `config` is neither a dict nor a str
"""
if isinstance(config, str):
self._config = Config.from_file(config)
elif isinstance(config, dict):
self._config = Config(config)
else:
raise TypeError("Invalid type for config, only dictionary or string is supported")
def set_dist_args(self, args):
"""Sets the distributed arguments.
:param args: The distributed arguments in the system
:type args: dict
"""
self._dist_args = args
@staticmethod
def _check_parallel_mode(parallel_mode: ParallelMode):
assert isinstance(parallel_mode, ParallelMode)
def get_global_rank(self):
"""Returns the global rank of the current device.
:return: The global rank of the current device
:rtype: int
"""
return self._global_ranks[ParallelMode.GLOBAL]
def add_global_rank(self, parallel_mode: ParallelMode, rank: int):
"""Adds the global rank of the current device for `parallel_mode` to the context.
:param parallel_mode: The parallel mode for the rank
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:param rank: The rank to be added
:type rank: int
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`
"""
self._check_parallel_mode(parallel_mode)
self._global_ranks[parallel_mode] = rank
def get_local_rank(self, parallel_mode: ParallelMode):
"""Returns the local rank of the current device.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`
:return: The local rank of the current device for `parallel_mode`
:rtype: int
"""
self._check_parallel_mode(parallel_mode)
return self._local_ranks[parallel_mode]
def add_local_rank(self, parallel_mode: ParallelMode, rank: int):
"""Adds the local rank of the current device for `parallel_mode` to the context.
:param parallel_mode: The parallel mode for the rank
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:param rank: The rank to be added
:type rank: int
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`
"""
self._check_parallel_mode(parallel_mode)
self._local_ranks[parallel_mode] = rank
def get_next_global_rank(self, parallel_mode: ParallelMode):
"""Returns the global rank of the next device.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`
:return: The global rank of the next device for `parallel_mode`
:rtype: int
"""
self._check_parallel_mode(parallel_mode)
# get rank and world size
local_rank = self.get_local_rank(parallel_mode)
world_size = self.get_world_size(parallel_mode)
ranks_in_group = self.get_ranks_in_group(parallel_mode)
return ranks_in_group[(local_rank + 1) % world_size]
def get_prev_global_rank(self, parallel_mode: ParallelMode):
"""Returns the global rank of the previous device.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`
:return: The global rank of the previous device for `parallel_mode`
:rtype: int
"""
self._check_parallel_mode(parallel_mode)
# get rank and world size
local_rank = self.get_local_rank(parallel_mode)
world_size = self.get_world_size(parallel_mode)
ranks_in_group = self.get_ranks_in_group(parallel_mode)
return ranks_in_group[(local_rank - 1) % world_size]
def is_first_rank(self, parallel_mode: ParallelMode):
"""Returns a boolean value indicating whether the current device is the first one
among its group for `parallel_mode`.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`
:return: a boolean value indicating whether the current device is the first one
among its group for `parallel_mode`
:rtype: bool
"""
rank = self.get_local_rank(parallel_mode)
return rank == 0
def is_last_rank(self, parallel_mode: ParallelMode):
"""Returns a boolean value indicating whether the current device is the last one
among its group for `parallel_mode`.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`
:return: a boolean value indicating whether the current device is the last one
among its group for `parallel_mode`
:rtype: bool
"""
rank = self.get_local_rank(parallel_mode)
world_size = self.get_world_size(parallel_mode)
return rank == world_size - 1
def get_world_size(self, parallel_mode: ParallelMode):
"""Returns the world size for `parallel_mode`.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`
:return: The world size for `parallel_mode`
:rtype: int
"""
self._check_parallel_mode(parallel_mode)
return self._world_sizes[parallel_mode]
def add_world_size(self, parallel_mode: ParallelMode, world_size: int):
"""Adds world size for `parallel_mode`.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:param world_size: The world size to be added
:type world_size: int
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`
"""
self._check_parallel_mode(parallel_mode)
self._world_sizes[parallel_mode] = world_size
def get_group(self, parallel_mode: ParallelMode):
"""Returns the group of the current device for `parallel_mode`.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`
:return: The group of the current device for `parallel_mode`
:rtype: torch.distributed.ProcessGroup
"""
self._check_parallel_mode(parallel_mode)
return self._groups[parallel_mode]
def add_group(self, parallel_mode: ParallelMode, group: dist.ProcessGroup):
"""Adds the group of the current device for `parallel_mode`.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:param group: The group to be added
:type group: torch.distributed.ProcessGroup
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`
"""
self._check_parallel_mode(parallel_mode)
self._groups[parallel_mode] = group
def get_ranks_in_group(self, parallel_mode: ParallelMode):
"""Returns the rank of the current device for `parallel_mode` in the group.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`
:return: The ranks in the group for `parallel_mode`
:rtype: list
"""
self._check_parallel_mode(parallel_mode)
return self._ranks_in_group[parallel_mode]
def add_ranks_in_group(self, parallel_mode: ParallelMode, ranks: list):
"""Adds the ranks of the current device for `parallel_mode` in the group.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:param ranks: List of ranks to be added
:type ranks: list
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`
"""
self._check_parallel_mode(parallel_mode)
self._ranks_in_group[parallel_mode] = ranks
def init_global_dist(self, addr=None, port=None):
"""Initializes the global distributed environment.
:param addr: The IP address of the master node used to initialize the process group
:type addr: str, optional
:param port: The port on the master node used to initialize the process group
:type port: int, optional
"""
# get config
rank = self._dist_args.local_rank
world_size = self._dist_args.world_size
# default env config, overwrite by exporting
# them in your bash script
addr = os.getenv('MASTER_ADDR', 'localhost') if addr is None else addr
port = os.getenv('MASTER_PORT', '8008') if port is None else port
init_method = f'tcp://{addr}:{port}'
dist.init_process_group(backend=self._dist_args.backend,
rank=rank,
world_size=world_size,
init_method=init_method)
# None will give the default global process group for pytorch dist operations
self._register_dist(rank, world_size, None,
list(range(world_size)), ParallelMode.GLOBAL)
self._global_ranks[ParallelMode.GLOBAL] = rank
def _register_dist(self, local_rank, world_size,
process_group, ranks_in_group, mode):
self.add_local_rank(mode, local_rank)
self.add_world_size(mode, world_size)
self.add_group(mode, process_group)
self.add_ranks_in_group(mode, ranks_in_group)
def check_sanity(self):
"""Checks sanity of the parallel context.
:raises AssertionError: Raises an AssertionError if the world size does not equal to the product
of data parallel size, pipeline parallel size and tensor parallel size
"""
dps = self.data_parallel_size
pps = self.pipeline_parallel_size
tps = self.tensor_parallel_size
ws = self.world_size
assert ws == dps * pps * tps, f"Expected the world size {ws} to be equal to data parallel size ({dps}) * pipeline parallel size ({pps}) * tensor parallel size ({tps})"
def init_parallel_groups(self):
"""Initializes the parallel groups.
:raises AssertionError: Raises an AssertionError if the field parallel is not present in the config file
"""
# get rank and world size
rank = self.get_global_rank()
world_size = self.get_world_size(ParallelMode.GLOBAL)
self.world_size = world_size
assert hasattr(self.config, 'parallel'), 'Expected the field parallel to be present in the config file'
# set parallel size as attributes for global context
parallel_config = self.config.parallel
set_parallel_size(self, parallel_config, 'pipeline',
'pipeline_parallel_size')
set_parallel_size(self, parallel_config, 'tensor',
'tensor_parallel_size')
# the user should not set the data parallel size manually
# instead, it should be calculated based on other parallel config
self.data_parallel_size = self.world_size // (self.pipeline_parallel_size * self.tensor_parallel_size)
# get the tensor parallel mode and check
tensor_parallel_mode = parallel_config['tensor'].get('mode', None)
assert tensor_parallel_mode in ALLOWED_MODES, f"mode in the parallel config must be set to one of {ALLOWED_MODES}"
self.check_sanity()
pg_init = []
# LSG: init data parallel process group for compatibility with other parallel modules such as zero
pg_init.append(dict(type=INITIALIZER_MAPPING['data']))
if self.pipeline_parallel_size > 1:
pg_init.append(dict(type=INITIALIZER_MAPPING['pipeline']))
pg_init.append(dict(type=INITIALIZER_MAPPING['tensor']))
# init specific tensor parallel group
if tensor_parallel_mode is not None:
tensor_parallel_cfg = parallel_config['tensor'].copy()
# remove duplicate parameters
tensor_parallel_cfg.pop('mode')
tensor_parallel_cfg.pop('size')
# add this config to initialize later
pg_init.append(dict(type=INITIALIZER_MAPPING[tensor_parallel_mode.lower()], **tensor_parallel_cfg))
# run initialization of different process groups
for initializer_cfg in pg_init:
cfg = initializer_cfg.copy()
initializer_type = cfg.pop('type')
initializer = DIST_GROUP_INITIALIZER.get_module(initializer_type)(
rank, world_size, self.config,
self.data_parallel_size,
self.pipeline_parallel_size,
self.tensor_parallel_size,
**cfg)
parallel_setting = initializer.init_dist_group()
if isinstance(parallel_setting, list):
for args in parallel_setting:
self._register_dist(*args)
else:
self._register_dist(*parallel_setting)
def is_initialized(self, parallel_mode: ParallelMode):
"""Returns a boolean value indicating whether `parallel_mode` is initialized
in the current system.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:return: a boolean value indicating whether `parallel_mode` is initialized
in the current system
:rtype: bool
"""
return parallel_mode in self._groups
def destroy(self):
"""Destroys the current distributed parallel environment.
"""
for mode, group in self._groups.items():
if mode is not ParallelMode.GLOBAL:
dist.destroy_process_group(group)
# destroy global process group
dist.destroy_process_group()
def set_device(self):
"""Sets distributed processes to be bound to devices.
"""
devices_per_node = torch.cuda.device_count()
global_rank = self.get_global_rank()
device = global_rank % devices_per_node
torch.cuda.set_device(device)
print(f'process rank {global_rank} is bound to device {device}')
def set_seed(self):
"""Sets seeds for all random libraries.
"""
if hasattr(self.config, 'seed'):
seed = getattr(self.config, 'seed')
else:
seed = 2 # default seed
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
global_rank = self.get_global_rank()
if torch.cuda.is_available():
# create random seed for different parallel modes
# data parallel seeds are kept the same
parallel_seed = seed
add_seed(ParallelMode.DATA, parallel_seed)
# model parallel seeds are different across ranks
pipeline_offset = self._local_ranks.get(ParallelMode.PIPELINE, 0)
# add seed for data parallel and tensor parallel only
if self.is_initialized(ParallelMode.TENSOR):
tp_rank = self.get_local_rank(ParallelMode.TENSOR)
# the offset of 1024 is only to increase the difference in seeds between pipeline stages
tp_rank_with_offset = tp_rank + pipeline_offset * 1024
tp_seed = seed + tp_rank_with_offset
add_seed(ParallelMode.TENSOR, tp_seed)
set_mode(ParallelMode.DATA)
seeds = get_seeds()
seed_str = ', '.join([f'{k}: {v}' for k, v in seeds.items()])
print(f"initialized seed on rank {global_rank}, "
f"numpy: {seed}, python random: {seed}, {seed_str},"
f"the default parallel seed is {ParallelMode.DATA}.", flush=True)
else:
print(f"initialized seed on rank {global_rank}, "
f"numpy: {seed}, python random: {seed}, pytorch: {seed}", flush=True)
print('WARNING: CUDA is not available, thus CUDA RNG cannot be used to track CUDA random number states',
flush=True)
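The relation enforced by init_parallel_groups() and check_sanity() is simple arithmetic; the sketch below uses assumed sizes (not taken from any shipped config) to show how the data parallel size is derived.

world_size = 16
pipeline_parallel_size = 2   # would come from parallel.pipeline.size in the config
tensor_parallel_size = 4     # would come from parallel.tensor.size in the config
data_parallel_size = world_size // (pipeline_parallel_size * tensor_parallel_size)
print(data_parallel_size)    # 2
# check_sanity() then asserts that 16 == 2 * 2 * 4
assert world_size == data_parallel_size * pipeline_parallel_size * tensor_parallel_size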

44
colossalai/context/parallel_mode.py

@ -0,0 +1,44 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from enum import Enum
# parallel modes
class ParallelMode(Enum):
"""This is an enumeration class containing all possible parallel modes.
"""
GLOBAL = 'global'
# common parallel
DATA = 'data'
# pipeline parallel
PIPELINE = 'pipe'
PIPELINE_PREV = 'pipe_prev'
PIPELINE_NEXT = 'pipe_next'
# containing all ranks in tensor parallel
TENSOR = 'tensor'
# sequence parallel
SEQUENCE = 'sequence'
# 1D Parallel
PARALLEL_1D = '1d'
# 2D parallel
PARALLEL_2D_ROW = '2d_row'
PARALLEL_2D_COL = '2d_col'
# 3D parallel
PARALLEL_3D_INPUT = '3d_input'
PARALLEL_3D_WEIGHT = '3d_weight'
PARALLEL_3D_OUTPUT = '3d_output'
# 2.5D parallel
PARALLEL_2P5D_ROW = '2p5d_row'
PARALLEL_2P5D_COL = '2p5d_col'
PARALLEL_2P5D_DEP = '2p5d_dep'
PARALLEL_2P5D_XZ = '2p5d_xz'

15
colossalai/context/process_group_initializer/__init__.py

@ -0,0 +1,15 @@
from .initializer_1d import Initializer_1D
from .initializer_2d import Initializer_2D
from .initializer_2p5d import Initializer_2p5D
from .initializer_3d import Initializer_3D
from .initializer_data import Initializer_Data
from .initializer_pipeline import Initializer_Pipeline
from .initializer_sequence import Initializer_Sequence
from .initializer_tensor import Initializer_Tensor
from .process_group_initializer import ProcessGroupInitializer
__all__ = [
'Initializer_Tensor', 'Initializer_Sequence', 'Initializer_Pipeline',
'Initializer_Data', 'Initializer_2p5D', 'Initializer_2D', 'Initializer_3D',
'Initializer_1D', 'ProcessGroupInitializer'
]

44
colossalai/context/process_group_initializer/initializer_1d.py

@ -0,0 +1,44 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.distributed as dist
from colossalai.context import Config
from colossalai.core import global_context as gpc
from colossalai.registry import DIST_GROUP_INITIALIZER
from .process_group_initializer import ProcessGroupInitializer
from ..parallel_mode import ParallelMode
@DIST_GROUP_INITIALIZER.register_module
class Initializer_1D(ProcessGroupInitializer):
'''A ProcessGroupInitializer for 1d tensor parallelism.
'''
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.num_group = self.world_size // self.tensor_parallel_size
def init_dist_group(self):
'''Initialize 1D tensor parallel groups, and assign local_ranks and groups to each gpu.
:return: (local_rank, group_world_size, process_group, ranks_in_group, mode)
:rtype: tuple
'''
local_rank = None
ranks_in_group = None
process_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_1D
for i in range(self.num_group):
ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)]
group = dist.new_group(ranks)
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
ranks_in_group = ranks
return local_rank, group_world_size, process_group, ranks_in_group, mode
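To make the rank arithmetic above concrete, this sketch reproduces only the grouping (no process groups are created); world_size = 8 and tensor_parallel_size = 2 are assumed values.

world_size, tensor_parallel_size = 8, 2
num_group = world_size // tensor_parallel_size
groups = [[i * tensor_parallel_size + j for j in range(tensor_parallel_size)]
          for i in range(num_group)]
print(groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]] -- each sublist becomes one 1D tensor parallel group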

123
colossalai/context/process_group_initializer/initializer_2d.py

@ -0,0 +1,123 @@
import math
import os
import torch.distributed as dist
from colossalai.constants import SUMMA_DIM
from colossalai.registry import DIST_GROUP_INITIALIZER
from .process_group_initializer import ProcessGroupInitializer
from ..parallel_mode import ParallelMode
def _check_summa_env_var(summa_dim):
# check environment variable for SUMMA
env_summa_dim = os.environ.get(SUMMA_DIM, None)
if env_summa_dim:
assert int(env_summa_dim) == summa_dim, \
'SUMMA_DIM has been set in the current environment and ' \
'does not match the value passed to this initializer'
else:
os.environ[SUMMA_DIM] = str(summa_dim)
class Initializer_2D_Row(ProcessGroupInitializer):
'''2d tensor parallel initialization among rows.
'''
def __init__(self, num_group, summa_dim, *args, **kwargs):
super(Initializer_2D_Row, self).__init__(*args, **kwargs)
self.num_group = num_group
self.summa_dim = summa_dim
def init_dist_group(self):
'''Initialize 2D tensor row parallel groups, and assign local_ranks and groups to each gpu.
:return: 2D tensor row parallelism's information
:rtype: tuple(local_rank, group_world_size, process_group, ranks_in_group, mode)
'''
local_rank = None
ranks_in_group = None
process_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_2D_ROW
for i in range(self.num_group):
for j in range(self.summa_dim):
ranks = [i * self.tensor_parallel_size + j * self.summa_dim + k
for k in range(self.summa_dim)]
group = dist.new_group(ranks)
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
ranks_in_group = ranks
return local_rank, group_world_size, process_group, ranks_in_group, mode
class Initializer_2D_Col(ProcessGroupInitializer):
'''2d tensor parallel initialization among cols.
'''
def __init__(self, num_group, summa_dim, *args, **kwargs):
super(Initializer_2D_Col, self).__init__(*args, **kwargs)
self.num_group = num_group
self.summa_dim = summa_dim
def init_dist_group(self):
'''Initialize 2D tensor col parallel groups, and assign local_ranks and groups to each gpu.
:return: 2D tensor col parallelism's information
:rtype: tuple(local_rank, group_world_size, process_group, ranks_in_group, mode)
'''
local_rank = None
ranks_in_group = None
process_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_2D_COL
for i in range(self.num_group):
for j in range(self.summa_dim):
ranks = [i * self.tensor_parallel_size + j + k * self.summa_dim
for k in range(self.summa_dim)]
group = dist.new_group(ranks)
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
ranks_in_group = ranks
return local_rank, group_world_size, process_group, ranks_in_group, mode
@DIST_GROUP_INITIALIZER.register_module
class Initializer_2D(ProcessGroupInitializer):
"""
Serve as the single entry point to 2D parallel initialization.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.num_group = self.world_size // self.tensor_parallel_size
self.summa_dim = int(math.sqrt(self.tensor_parallel_size))
assert self.tensor_parallel_size == self.summa_dim ** 2, \
"2D summa dim should equal to tensor parallel size ^ 0.5"
_check_summa_env_var(self.summa_dim)
self.col_initializer = Initializer_2D_Col(self.num_group, self.summa_dim, *args, **kwargs)
self.row_initializer = Initializer_2D_Row(self.num_group, self.summa_dim, *args, **kwargs)
def init_dist_group(self):
'''Initialize 2D tensor row and col parallel groups, and assign local_ranks and groups to each gpu.
:return: 2D tensor parallelism's information
:rtype: list of tuples (local_rank, group_world_size, process_group, ranks_in_group, mode)
'''
parallel_setting = []
parallel_setting.append(self.row_initializer.init_dist_group())
parallel_setting.append(self.col_initializer.init_dist_group())
return parallel_setting
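A minimal sketch of the SUMMA grid implied by the loops above, assuming tensor_parallel_size = 4 (so summa_dim = 2) and a single tensor parallel group; it reproduces the rank arithmetic only.

summa_dim, tensor_parallel_size, num_group = 2, 4, 1
rows = [[i * tensor_parallel_size + j * summa_dim + k for k in range(summa_dim)]
        for i in range(num_group) for j in range(summa_dim)]
cols = [[i * tensor_parallel_size + j + k * summa_dim for k in range(summa_dim)]
        for i in range(num_group) for j in range(summa_dim)]
print(rows)  # [[0, 1], [2, 3]] -- ranks in the same row of the 2 x 2 grid
print(cols)  # [[0, 2], [1, 3]] -- ranks in the same column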

255
colossalai/context/process_group_initializer/initializer_2p5d.py

@ -0,0 +1,255 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math
import os
import torch.distributed as dist
from colossalai.constants import TESSERACT_DIM, TESSERACT_DEP
from colossalai.context import Config
from colossalai.core import global_context as gpc
from colossalai.registry import DIST_GROUP_INITIALIZER
from .process_group_initializer import ProcessGroupInitializer
from ..parallel_mode import ParallelMode
def _check_tesseract_env_var(tesseract_dim: int,
tesseract_dep: int):
# check environment variable for TESSERACT
env_tesseract_dim = os.environ.get(TESSERACT_DIM, None)
env_tesseract_dep = os.environ.get(TESSERACT_DEP, None)
if env_tesseract_dim and env_tesseract_dep:
assert int(env_tesseract_dim) == tesseract_dim, \
'TESSERACT_DIM has been set in the current environment and ' \
'does not match the value passed to this initializer'
assert int(env_tesseract_dep) == tesseract_dep, \
'TESSERACT_DEP has been set in the current environment and ' \
'does not match the value passed to this initializer'
else:
os.environ[TESSERACT_DIM] = str(tesseract_dim)
os.environ[TESSERACT_DEP] = str(tesseract_dep)
# i row j col k dep
class Initializer_2p5D_ROW(ProcessGroupInitializer):
'''2p5d tensor parallel initialization among rows.
'''
def __init__(self,
tesseract_dim: int,
tesseract_dep: int,
*args):
super(Initializer_2p5D_ROW, self).__init__(*args)
self.tensor_parallel_size = gpc.tensor_parallel_size
self.num_group = self.world_size // self.tensor_parallel_size
self.tesseract_dep = tesseract_dep
self.tesseract_dim = tesseract_dim
assert self.tensor_parallel_size == self.tesseract_dim ** 2 * self.tesseract_dep, \
"Tensor parallel size should be depth * dim ** 2 in 2.5D parallel"
def init_dist_group(self):
'''Initialize 2p5D tensor row parallel groups, and assign local_ranks and groups to each gpu.
:return: 2p5D tensor row parallelism's information
:rtype: tuple(local_rank, group_world_size, process_group, ranks_in_group, mode)
'''
local_rank = None
ranks_in_group = None
process_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_2P5D_ROW
for h in range(self.num_group):
for j in range(self.tesseract_dim):
for k in range(self.tesseract_dep):
ranks = [h * self.tensor_parallel_size + i + self.tesseract_dim * (
j + self.tesseract_dim * k) for i in range(self.tesseract_dim)]
group = dist.new_group(ranks)
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
ranks_in_group = ranks
return local_rank, group_world_size, process_group, ranks_in_group, mode
class Initializer_2p5D_Col(ProcessGroupInitializer):
'''2p5d tensor parallel initialization among cols.
'''
def __init__(self,
tesseract_dim: int,
tesseract_dep: int,
*args):
super(Initializer_2p5D_Col, self).__init__(*args)
self.tensor_parallel_size = gpc.tensor_parallel_size
self.num_group = self.world_size // self.tensor_parallel_size
self.tesseract_dep = tesseract_dep
self.tesseract_dim = tesseract_dim
assert self.tensor_parallel_size == self.tesseract_dim ** 2 * self.tesseract_dep, \
"Tensor parallel size should be depth * dim ** 2 in 2.5D parallel"
def init_dist_group(self):
'''Initialize 2p5D tensor col parallel groups, and assign local_ranks and groups to each gpu.
:return: 2p5D tensor col parallelism's information
:rtype: tuple(local_rank, group_world_size, process_group, ranks_in_group, mode)
'''
local_rank = None
ranks_in_group = None
process_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_2P5D_COL
for h in range(self.num_group):
for i in range(self.tesseract_dim):
for k in range(self.tesseract_dep):
ranks = [h * self.tensor_parallel_size + i + self.tesseract_dim * (
j + self.tesseract_dim * k) for j in range(self.tesseract_dim)]
group = dist.new_group(ranks)
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
ranks_in_group = ranks
return local_rank, group_world_size, process_group, ranks_in_group, mode
class Initializer_2p5D_Dep(ProcessGroupInitializer):
'''2p5D tensor parallel initialization among depths.
'''
def __init__(self,
tesseract_dim: int,
tesseract_dep: int,
*args):
super(Initializer_2p5D_Dep, self).__init__(*args)
self.tensor_parallel_size = gpc.tensor_parallel_size
self.num_group = self.world_size // self.tensor_parallel_size
self.tesseract_dep = tesseract_dep
self.tesseract_dim = tesseract_dim
assert self.tensor_parallel_size == self.tesseract_dim ** 2 * self.tesseract_dep, \
"Tensor parallel size should be depth * dim ** 2 in 2.5D parallel"
def init_dist_group(self):
'''Initialize 2p5D tensor depth parallel groups, and assign local_ranks and groups to each gpu.
:return: 2p5D tensor depth parallelism's information
:rtype: tuple(local_rank, group_world_size, process_group, ranks_in_group, mode)
'''
local_rank = None
ranks_in_group = None
process_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_2P5D_DEP
for h in range(self.num_group):
for i in range(self.tesseract_dim):
for j in range(self.tesseract_dim):
ranks = [h * self.tensor_parallel_size + i + self.tesseract_dim * (
j + self.tesseract_dim * k) for k in range(self.tesseract_dep)]
group = dist.new_group(ranks)
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
ranks_in_group = ranks
return local_rank, group_world_size, process_group, ranks_in_group, mode
# i row j col k dep
class Initializer_2p5D_XZ(ProcessGroupInitializer):
'''2p5d tensor parallel initialization among the col x dep (XZ) planes.
'''
def __init__(self,
tesseract_dim: int,
tesseract_dep: int,
*args):
super(Initializer_2p5D_XZ, self).__init__(*args)
self.tensor_parallel_size = gpc.tensor_parallel_size
self.num_group = self.world_size // self.tensor_parallel_size
self.tesseract_dep = tesseract_dep
self.tesseract_dim = tesseract_dim
assert self.tensor_parallel_size == self.tesseract_dim ** 2 * self.tesseract_dep, \
"Tensor parallel size should be depth * dim ** 2 in 2.5D parallel"
def init_dist_group(self):
'''Initialize 2p5D tensor colXdepth parallel groups, and assign local_ranks and groups to each gpu.
:return: 2p5D tensor colXdepth parallelism's information
:rtype: tuple(local_rank, group_world_size, process_group, ranks_in_group, mode)
'''
local_rank = None
ranks_in_group = None
process_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_2P5D_XZ
for h in range(self.num_group):
for i in range(self.tesseract_dim):
ranks = [h * self.tensor_parallel_size + i + self.tesseract_dim * (
j + self.tesseract_dim * k) for k in range(self.tesseract_dep) for j in
range(self.tesseract_dim)]
group = dist.new_group(ranks)
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
ranks_in_group = ranks
return local_rank, group_world_size, process_group, ranks_in_group, mode
@DIST_GROUP_INITIALIZER.register_module
class Initializer_2p5D(ProcessGroupInitializer):
"""
Serve as the single entry point to Tesseract parallel initialization.
"""
def __init__(self,
rank: int,
world_size: int,
config: Config,
data_parallel_size: int,
pipeline_parallel_size: int,
tensor_parallel_size: int,
depth: int
):
args = (rank, world_size, config, data_parallel_size, pipeline_parallel_size, tensor_parallel_size)
super().__init__(*args)
self.num_group = self.world_size // self.tensor_parallel_size
self.tesseract_dim = int(math.sqrt(self.tensor_parallel_size / depth))
self.tesseract_dep = depth
assert self.tensor_parallel_size == self.tesseract_dim ** 2 * self.tesseract_dep, \
"2.5D tesseract dim should equal to (tensor parallel size / tesseract dep) ^ 0.5"
_check_tesseract_env_var(self.tesseract_dim, self.tesseract_dep)
self.col_initializer = Initializer_2p5D_Col(self.tesseract_dim, self.tesseract_dep, *args)
self.row_initializer = Initializer_2p5D_ROW(self.tesseract_dim, self.tesseract_dep, *args)
self.dep_initializer = Initializer_2p5D_Dep(self.tesseract_dim, self.tesseract_dep, *args)
self.xz_initializer = Initializer_2p5D_XZ(self.tesseract_dim, self.tesseract_dep, *args)
def init_dist_group(self):
'''Initialize 2p5D tensor row, col, depth, and colXdepth parallel groups, and assign local_ranks and groups to each gpu.
:return: Whole 2p5D tensor parallelism's information
:rtype: list of tuples (local_rank, group_world_size, process_group, ranks_in_group, mode)
'''
parallel_setting = []
parallel_setting.append(self.col_initializer.init_dist_group())
parallel_setting.append(self.row_initializer.init_dist_group())
parallel_setting.append(self.dep_initializer.init_dist_group())
parallel_setting.append(self.xz_initializer.init_dist_group())
return parallel_setting
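The tesseract layout is easier to see with concrete numbers; the sketch below assumes tesseract_dim = 2 and tesseract_dep = 2 (so tensor_parallel_size = 8) with a single tensor parallel group, and reproduces the rank arithmetic only.

dim, dep, tp, h = 2, 2, 8, 0
row_groups = [[h * tp + i + dim * (j + dim * k) for i in range(dim)]
              for j in range(dim) for k in range(dep)]
col_groups = [[h * tp + i + dim * (j + dim * k) for j in range(dim)]
              for i in range(dim) for k in range(dep)]
dep_groups = [[h * tp + i + dim * (j + dim * k) for k in range(dep)]
              for i in range(dim) for j in range(dim)]
xz_groups = [[h * tp + i + dim * (j + dim * k) for k in range(dep) for j in range(dim)]
             for i in range(dim)]
print(row_groups)  # [[0, 1], [4, 5], [2, 3], [6, 7]]
print(col_groups)  # [[0, 2], [4, 6], [1, 3], [5, 7]]
print(dep_groups)  # [[0, 4], [2, 6], [1, 5], [3, 7]]
print(xz_groups)   # [[0, 2, 4, 6], [1, 3, 5, 7]]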

172
colossalai/context/process_group_initializer/initializer_3d.py

@ -0,0 +1,172 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math
import os
import torch.distributed as dist
from colossalai.constants import DEPTH_3D
from colossalai.registry import DIST_GROUP_INITIALIZER
from ..parallel_mode import ParallelMode
from .process_group_initializer import ProcessGroupInitializer
def _check_depth_env_var(depth):
# check environment variable for 3D depth
env_depth = os.environ.get(DEPTH_3D, None)
if env_depth:
assert int(env_depth) == depth, \
'DEPTH_3D has been set in the current environment and ' \
'does not match the value passed to this initializer'
else:
os.environ[DEPTH_3D] = str(depth)
class Initializer_3D_Input(ProcessGroupInitializer):
'''3D tensor parallel initialization among input.
'''
def __init__(self, num_group: int, depth: int, *args):
super().__init__(*args)
self.num_group = num_group
self.depth = depth
def init_dist_group(self):
'''Initialize 3D tensor parallel groups among input, and assign local_ranks and groups to each gpu.
:return: 3D tensor parallelism's information among input
:rtype: tuple(local_rank, group_world_size, process_group, ranks_in_group, mode)
'''
local_rank = None
ranks_in_group = None
process_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_3D_INPUT
for h in range(self.num_group):
for i in range(self.depth):
for k in range(self.depth):
ranks = [
h * self.depth**3 + i + self.depth *
(j + self.depth * k) for j in range(self.depth)
]
group = dist.new_group(ranks)
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
ranks_in_group = ranks
return local_rank, group_world_size, process_group, ranks_in_group, mode
class Initializer_3D_Weight(ProcessGroupInitializer):
'''3D tensor parallel initialization among weight.
'''
def __init__(self, num_group: int, depth: int, *args):
super().__init__(*args)
self.num_group = num_group
self.depth = depth
def init_dist_group(self):
'''Initialize 3D tensor parallel groups among weight, and assign local_ranks and groups to each gpu.
:return: 3D tensor parallelism's information among weight
:rtype: tuple(local_rank, group_world_size, process_group, ranks_in_group, mode)
'''
local_rank = None
ranks_in_group = None
process_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_3D_WEIGHT
for h in range(self.num_group):
for k in range(self.depth):
for j in range(self.depth):
ranks = [
h * self.depth**3 + i + self.depth *
(j + self.depth * k) for i in range(self.depth)
]
group = dist.new_group(ranks)
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
ranks_in_group = ranks
return local_rank, group_world_size, process_group, ranks_in_group, mode
class Initializer_3D_Output(ProcessGroupInitializer):
'''3D tensor parallel initialization among output.
'''
def __init__(self, num_group: int, depth: int, *args):
super().__init__(*args)
self.num_group = num_group
self.depth = depth
def init_dist_group(self):
'''Initialize 3D tensor parallel groups among output, and assign local_ranks and groups to each gpu.
:return: 3D tensor parallelism's information among output
:rtype: tuple(local_rank, group_world_size, process_group, ranks_in_group, mode)
'''
local_rank = None
ranks_in_group = None
process_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_3D_OUTPUT
for h in range(self.num_group):
for i in range(self.depth):
for j in range(self.depth):
ranks = [
h * self.depth**3 + i + self.depth *
(j + self.depth * k) for k in range(self.depth)
]
group = dist.new_group(ranks)
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
ranks_in_group = ranks
return local_rank, group_world_size, process_group, ranks_in_group, mode
@DIST_GROUP_INITIALIZER.register_module
class Initializer_3D(ProcessGroupInitializer):
'''Serve as the single entry point to 3D parallel initialization.
'''
def __init__(self, *args):
super().__init__(*args)
self.num_group = self.world_size // self.tensor_parallel_size
self.depth = round(math.pow(self.tensor_parallel_size, 1 / 3))
assert self.tensor_parallel_size == self.depth ** 3, \
f'3D depth ({self.depth}) is not the cube root of tensor parallel size ({self.tensor_parallel_size})'
_check_depth_env_var(self.depth)
self.input_initializer = Initializer_3D_Input(self.num_group,
self.depth, *args)
self.weight_initializer = Initializer_3D_Weight(
self.num_group, self.depth, *args)
self.output_initializer = Initializer_3D_Output(
self.num_group, self.depth, *args)
def init_dist_group(self):
'''Initialize 3D tensor parallel groups, and assign local_ranks and groups to each gpu.
:return: 3D tensor parallelism's information
:rtype: list of tuples (local_rank, group_world_size, process_group, ranks_in_group, mode)
'''
parallel_setting = []
parallel_setting.append(self.input_initializer.init_dist_group())
parallel_setting.append(self.weight_initializer.init_dist_group())
parallel_setting.append(self.output_initializer.init_dist_group())
return parallel_setting
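For concreteness, the sketch below reproduces the rank arithmetic for an assumed depth of 2 (tensor_parallel_size = 8) and a single tensor parallel group; no process groups are created.

d, h = 2, 0
input_groups = [[h * d**3 + i + d * (j + d * k) for j in range(d)]
                for i in range(d) for k in range(d)]
weight_groups = [[h * d**3 + i + d * (j + d * k) for i in range(d)]
                 for k in range(d) for j in range(d)]
output_groups = [[h * d**3 + i + d * (j + d * k) for k in range(d)]
                 for i in range(d) for j in range(d)]
print(input_groups)   # [[0, 2], [4, 6], [1, 3], [5, 7]]
print(weight_groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(output_groups)  # [[0, 4], [2, 6], [1, 5], [3, 7]]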

41
colossalai/context/process_group_initializer/initializer_data.py

@ -0,0 +1,41 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from torch import distributed as dist
from colossalai.registry import DIST_GROUP_INITIALIZER
from .process_group_initializer import ProcessGroupInitializer
from ..parallel_mode import ParallelMode
@DIST_GROUP_INITIALIZER.register_module
class Initializer_Data(ProcessGroupInitializer):
'''A ProcessGroupInitializer for data parallelism.
'''
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.num_data_parallel_group = self.world_size // self.data_parallel_size
def init_dist_group(self):
'''Initialize data parallel groups, and assign local_ranks and groups to each gpu.
:return: data parallelism's information
:rtype: tuple (local_rank, group_world_size, process_group, ranks_in_group, mode)
'''
local_rank = None
ranks_in_group = None
process_group = None
group_world_size = None
mode = ParallelMode.DATA
for i in range(self.num_data_parallel_group):
ranks = [i + j * self.num_data_parallel_group for j in range(self.data_parallel_size)]
group = dist.new_group(ranks)
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
ranks_in_group = ranks
return local_rank, group_world_size, process_group, ranks_in_group, mode
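A worked example of the grouping above, with assumed sizes only: with 8 ranks and data_parallel_size = 2, each data parallel group pairs rank i with rank i + 4.

world_size, data_parallel_size = 8, 2
num_data_parallel_group = world_size // data_parallel_size
groups = [[i + j * num_data_parallel_group for j in range(data_parallel_size)]
          for i in range(num_data_parallel_group)]
print(groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]]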

63
colossalai/context/process_group_initializer/initializer_pipeline.py

@ -0,0 +1,63 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from torch import distributed as dist
from colossalai.registry import DIST_GROUP_INITIALIZER
from .process_group_initializer import ProcessGroupInitializer
from ..parallel_mode import ParallelMode
@DIST_GROUP_INITIALIZER.register_module
class Initializer_Pipeline(ProcessGroupInitializer):
'''A ProcessGroupInitializer for pipeline parallelism.
'''
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.data_group_size = self.world_size // self.data_parallel_size
self.pipeline_stage_size = self.data_group_size // self.pipeline_parallel_size
def init_dist_group(self):
'''Initialize pipeline parallel groups (together with the point-to-point PREV/NEXT pairs), and assign local_ranks and groups to each gpu.
:return: Pipeline parallelism's information
:rtype: list of tuples (local_rank, group_world_size, process_group, ranks_in_group, mode)
'''
dist_settings = list()
for i in range(self.data_parallel_size):
for j in range(self.pipeline_stage_size):
pipe_ranks = list(
range(i * self.data_group_size + j,
(i + 1) * self.data_group_size,
self.pipeline_stage_size))
pipe_group_size = len(pipe_ranks)
pipe_group = dist.new_group(pipe_ranks)
if self.rank in pipe_ranks:
local_rank = pipe_ranks.index(self.rank)
group_world_size = pipe_group_size
process_group = pipe_group
ranks_in_group = pipe_ranks
dist_settings.append(
tuple((local_rank, group_world_size,
process_group, ranks_in_group,
ParallelMode.PIPELINE)))
for k in range(pipe_group_size):
first = pipe_ranks[k]
second = pipe_ranks[(k + 1) % pipe_group_size]
ranks = [first, second]
group = dist.new_group(ranks)
if self.rank == first:
local_rank = 0
group_world_size = 2
process_group = group
ranks_in_group = ranks
dist_settings.append(
tuple((local_rank, group_world_size,
process_group, ranks_in_group,
ParallelMode.PIPELINE_NEXT)))
elif self.rank == second:
local_rank = 1
group_world_size = 2
process_group = group
ranks_in_group = ranks
dist_settings.append(
tuple((local_rank, group_world_size,
process_group, ranks_in_group,
ParallelMode.PIPELINE_PREV)))
return dist_settings
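The stage grouping above is easiest to read with concrete numbers; this sketch assumes world_size = 8, data_parallel_size = 2 and pipeline_parallel_size = 2, and reproduces only the pipe_ranks computation (neither the PIPELINE groups nor the PREV/NEXT pairs are created).

world_size, data_parallel_size, pipeline_parallel_size = 8, 2, 2
data_group_size = world_size // data_parallel_size                # 4
pipeline_stage_size = data_group_size // pipeline_parallel_size   # 2
pipe_groups = [list(range(i * data_group_size + j,
                          (i + 1) * data_group_size,
                          pipeline_stage_size))
               for i in range(data_parallel_size) for j in range(pipeline_stage_size)]
print(pipe_groups)  # [[0, 2], [1, 3], [4, 6], [5, 7]] -- each sublist is one pipeline of 2 stages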

27
colossalai/context/process_group_initializer/initializer_sequence.py

@ -0,0 +1,27 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from colossalai.registry import DIST_GROUP_INITIALIZER
from .initializer_tensor import Initializer_Tensor
from .process_group_initializer import ProcessGroupInitializer
from ..parallel_mode import ParallelMode
@DIST_GROUP_INITIALIZER.register_module
class Initializer_Sequence(ProcessGroupInitializer):
'''A ProcessGroupInitializer for sequence parallelism.
'''
def __init__(self,
*args, **kwargs):
super().__init__(*args, **kwargs)
# reuse tensor parallel code
self._initializer = Initializer_Tensor(*args, **kwargs)
def init_dist_group(self):
local_rank, group_world_size, process_group, ranks_in_group, mode = self._initializer.init_dist_group()
# change mode to sequence
mode = ParallelMode.SEQUENCE
return local_rank, group_world_size, process_group, ranks_in_group, mode

41
colossalai/context/process_group_initializer/initializer_tensor.py

@ -0,0 +1,41 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.distributed as dist
from colossalai.registry import DIST_GROUP_INITIALIZER
from .process_group_initializer import ProcessGroupInitializer
from ..parallel_mode import ParallelMode
@DIST_GROUP_INITIALIZER.register_module
class Initializer_Tensor(ProcessGroupInitializer):
'''A ProcessGroupInitializer for tensor parallelism.
'''
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.num_tensor_parallel_group = self.world_size // self.tensor_parallel_size
def init_dist_group(self):
'''Initialize tensor parallel groups, and assign local_ranks and groups to each gpu.
:return: tensor parallelism's information
:rtype: tuple(local_rank, group_world_size, process_group, ranks_in_group, mode)
'''
local_rank = None
ranks_in_group = None
process_group = None
group_world_size = None
mode = ParallelMode.TENSOR
for i in range(self.num_tensor_parallel_group):
ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)]
group = dist.new_group(ranks)
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
ranks_in_group = ranks
return local_rank, group_world_size, process_group, ranks_in_group, mode

30
colossalai/context/process_group_initializer/process_group_initializer.py

@ -0,0 +1,30 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from abc import ABC, abstractmethod
from colossalai.context import Config
class ProcessGroupInitializer(ABC):
'''An object that holds the parallelism configuration and initializes the corresponding parallel groups.
'''
def __init__(self,
rank: int,
world_size: int,
config: Config,
data_parallel_size: int,
pipeline_parallel_size: int,
tensor_parallel_size: int
):
self.rank = rank
self.world_size = world_size
self.data_parallel_size = data_parallel_size
self.config = config
self.pipeline_parallel_size = pipeline_parallel_size
self.tensor_parallel_size = tensor_parallel_size
super().__init__()
@abstractmethod
def init_dist_group(self):
pass

8
colossalai/context/random/__init__.py

@ -0,0 +1,8 @@
from ._helper import (seed, set_mode, with_seed, add_seed,
get_seeds, get_states, get_current_mode,
set_seed_states, sync_states)
__all__ = [
'seed', 'set_mode', 'with_seed', 'add_seed', 'get_seeds',
'get_states', 'get_current_mode', 'set_seed_states', 'sync_states'
]

144
colossalai/context/random/_helper.py

@ -0,0 +1,144 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import functools
from contextlib import contextmanager
import torch.cuda
from torch import Tensor
from .seed_manager import SeedManager
from ..parallel_mode import ParallelMode
_SEED_MANAGER = SeedManager()
def get_seeds():
"""Returns the seeds of the seed manager.
:return: The seeds of the seed manager
:rtype: dict
"""
return _SEED_MANAGER.seeds
def get_states(copy=False):
"""Returns the seed states of the seed manager.
:return: The seed states of the seed manager
:rtype: dict
"""
states = _SEED_MANAGER.seed_states
if copy:
new_states = dict()
for parallel_mode, state in states.items():
new_states[parallel_mode] = state.clone()
return new_states
else:
return _SEED_MANAGER.seed_states
def get_current_mode():
"""Returns the current mode of the seed manager.
:return: The current mode of the seed manager.
:rtype: :class:`colossalai.context.ParallelMode`
"""
return _SEED_MANAGER.current_mode
def add_seed(parallel_mode: ParallelMode, seed: int):
"""Adds a seed to the seed manager for `parallel_mode`.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:param seed: The seed to be added
:type seed: int
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of
:class:`colossalai.context.ParallelMode` or the seed for `parallel_mode` has been added
"""
_SEED_MANAGER.add_seed(parallel_mode, seed)
def set_mode(parallel_mode: ParallelMode):
"""Sets the current mode of the seed manager.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
"""
_SEED_MANAGER.set_mode(parallel_mode)
def set_seed_states(parallel_mode: ParallelMode, state: Tensor):
"""Sets the state of the seed manager for `parallel_mode`.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:param state: the state to be set
:type state: :class:`torch.Tensor`
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not found in the seed manager
"""
_SEED_MANAGER.set_state(parallel_mode, state)
def sync_states():
"""Syncs the CUDA RNG state of the current mode back into the seed manager."""
current_mode = get_current_mode()
current_states = torch.cuda.get_rng_state()
set_seed_states(current_mode, current_states)
@contextmanager
def seed(parallel_mode: ParallelMode):
""" A context for seed switch
Examples::
with seed(ParallelMode.DATA):
output = F.dropout(input)
"""
try:
# set to new mode
current_mode = _SEED_MANAGER.current_mode
yield _SEED_MANAGER.set_mode(parallel_mode)
finally:
# recover
_SEED_MANAGER.set_mode(current_mode)
def with_seed(func, parallel_mode: ParallelMode):
"""
A function wrapper which executes the function with a specified seed.
Examples::
# use with decorator
@with_seed(ParallelMode.DATA)
def forward(input):
return F.dropout(input)
out = forward(input)
# OR use it inline
def forward(input):
return F.dropout(input)
wrapped_forward = with_seed(forward, ParallelMode.DATA)
out = wrapped_forward(input)
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
# switch mode
current_mode = _SEED_MANAGER.current_mode
_SEED_MANAGER.set_mode(parallel_mode)
# exec func
out = func(*args, **kwargs)
# recover state
_SEED_MANAGER.set_mode(current_mode)
return out
return wrapper

74
colossalai/context/random/seed_manager.py

@ -0,0 +1,74 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
from torch import Tensor
from colossalai.context.parallel_mode import ParallelMode
class SeedManager:
"""This class is a manager of all random seeds involved in the system.
"""
def __init__(self):
self._current_mode = None
self._seeds = dict()
self._seed_states = dict()
@property
def current_mode(self):
return self._current_mode
@property
def seeds(self):
return self._seeds
@property
def seed_states(self):
return self._seed_states
def set_state(self, parallel_mode: ParallelMode, state: Tensor):
"""Sets the state of the seed manager for `parallel_mode`.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:param state: the state to be set
:type state: :class:`torch.Tensor`
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not found in the seed manager
"""
assert parallel_mode in self._seed_states, f'Parallel mode {parallel_mode} is not found in the seed manager'
self._seed_states[parallel_mode] = state
def set_mode(self, parallel_mode: ParallelMode):
"""Sets the current mode of the seed manager.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
"""
if self.current_mode:
# save the current state for current mode
self._seed_states[self._current_mode] = torch.cuda.get_rng_state()
# set the new state for new mode
self._current_mode = parallel_mode
torch.cuda.set_rng_state(self._seed_states[parallel_mode])
def add_seed(self, parallel_mode: ParallelMode, seed: int):
"""Adds a seed to the seed manager for `parallel_mode`.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:param seed: The seed to be added
:type seed: int
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of
:class:`colossalai.context.ParallelMode` or the seed for `parallel_mode` has been added
"""
assert isinstance(
parallel_mode, ParallelMode), 'A valid ParallelMode must be provided'
assert parallel_mode not in self._seed_states, f'The seed for {parallel_mode} has been added'
current_state = torch.cuda.get_rng_state()
torch.cuda.manual_seed(seed)
self._seed_states[parallel_mode] = torch.cuda.get_rng_state()
self._seeds[parallel_mode] = seed
torch.cuda.set_rng_state(current_state)
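A hedged usage sketch of the manager above; it assumes the colossalai package is importable and a CUDA device is present, and the seed values are arbitrary.

import torch
from colossalai.context.parallel_mode import ParallelMode
from colossalai.context.random.seed_manager import SeedManager

mgr = SeedManager()
mgr.add_seed(ParallelMode.DATA, 1024)    # registers a CUDA RNG state seeded with 1024
mgr.add_seed(ParallelMode.TENSOR, 1025)  # a separate state for tensor parallel ops
mgr.set_mode(ParallelMode.TENSOR)        # loads the TENSOR RNG state onto the device
x = torch.rand(2, device='cuda')         # drawn from the TENSOR stream
mgr.set_mode(ParallelMode.DATA)          # saves the TENSOR state, restores the DATA state
y = torch.rand(2, device='cuda')         # drawn from the DATA stream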

16
colossalai/core.py

@ -0,0 +1,16 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from colossalai.context import ParallelContext
global_context = ParallelContext()
def set_global_context(context: ParallelContext):
'''Resets the global context to be identical to a given :class:`ParallelContext`.
:param context: Parallel context to generate our global parallel context.
:type context: ParallelContext
'''
global global_context
global_context = context

7
colossalai/engine/__init__.py

@ -0,0 +1,7 @@
from .amp_type import AMP_TYPE
from ._base_engine import Engine
from .gradient_handler import *
from .schedule import *
__all__ = ['Engine']

170
colossalai/engine/_base_engine.py

@ -0,0 +1,170 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Optional
from colossalai.builder import build_gradient_handler
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger
from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)
from torch.nn import Module
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from .schedule import BaseSchedule, NoPipelineSchedule
class Engine:
"""Basic engine class for training and evaluation. It runs a specific process method
:meth:`step` which is based on the given :attr:`schedule` over each batch of a dataset.
:param train_dataloader: Dataloader in training
:param test_dataloader: Dataloader in evaluation
:param model: The neural network model
:param criterion: Criterion for calculating loss
:param optimizer: Optimizer for updating the parameters
:param lr_scheduler: Learning rate scheduler adjusting the learning rate during training or evaluation
:param schedule: Running schedule in :meth:`step`
:type train_dataloader: DataLoader, optional
:type test_dataloader: DataLoader, optional
:type model: Module
:type criterion: _Loss, optional
:type optimizer: Optimizer, optional
:type lr_scheduler: _LRScheduler, optional
:type schedule: BaseSchedule, optional
"""
def __init__(self,
train_dataloader: Optional[DataLoader] = None,
test_dataloader: Optional[DataLoader] = None,
model: Module = None,
criterion: _Loss = None,
optimizer: Optimizer = None,
lr_scheduler: Optional[_LRScheduler] = None,
schedule: BaseSchedule = None):
self.train_dataloader = train_dataloader
self.test_dataloader = test_dataloader
assert model is not None, "Engine requires a model"
self.model = model
self.criterion = criterion
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.schedule = schedule if schedule is not None \
else NoPipelineSchedule()
self._logger = get_global_dist_logger()
# build gradient handler
self._gradient_handlers = []
gradient_handler_cfg = []
if hasattr(gpc.config, 'gradient_handler'):
assert isinstance(gpc.config.gradient_handler, list), \
f'argument gradient_handler_cfg expected type list, ' \
f'but got type {type(gpc.config.gradient_handler)}'
gradient_handler_cfg = gpc.config.gradient_handler
elif isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)):
gradient_handler_cfg = [dict(type='ZeROGradientHandler')]
self._logger.info(
"Training with zero is detected, ZeROGradientHandler is automatically "
"added even though not specified in the configuration",
ranks=[0])
elif gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(
ParallelMode.DATA) > 1:
gradient_handler_cfg = [dict(type='DataParallelGradientHandler')]
self._logger.info(
"Data parallel training is detected, DataParallelGradientHandler is automatically "
"added even though not specified in the configuration",
ranks=[0])
if len(gradient_handler_cfg) == 0:
self._logger.warning(
"No gradient handler is set up, please make sure you do not need "
"to all-reduce the gradients after a training step.",
ranks=[0])
for cfg in gradient_handler_cfg:
handler = build_gradient_handler(cfg, self.model, self.optimizer)
self._gradient_handlers.append(handler)
self.schedule.initialize(self.train_dataloader, self.model,
self.criterion, self.optimizer,
self.lr_scheduler)
self.forward_only = False
def handle_gradient(self):
"""Handles all-reduce operations of gradients across different parallel groups.
"""
for handler in self._gradient_handlers:
handler.handle_gradient()
def set_dataloader(self, data: DataLoader, train: bool = True):
"""Sets dataloader in training or evaluation.
:param data: Dataloader to be set
:param train: Set training dataloader if True, otherwise evaluation dataloader
:type data: DataLoader
:type train: bool
"""
if train:
self.train_dataloader = data
else:
self.test_dataloader = data
def get_model(self):
"""Returns the neural network model in the engine.
"""
return self.model
def get_optimizer(self):
"""Returns optimizier in the engine.
"""
return self.optimizer
def get_lr_scheduler(self):
"""Returns the learning rate scheduler in the engine.
"""
return self.lr_scheduler
def train(self):
"""Sets the model to training mode.
"""
self.forward_only = False
self.schedule.train(dataloader=self.train_dataloader, mode=True)
def eval(self):
"""Sets the model to evaluation mode.
"""
self.forward_only = True
self.schedule.train(dataloader=self.test_dataloader, mode=False)
def is_train(self):
"""Returns True if it is in training, otherwise False.
"""
return not self.forward_only
def get_lr(self):
"""Gets current learning rate.
"""
return self.schedule.get_lr()
def step(self, return_loss=True):
"""A running step based on the schedule. Usually, it runs a training or
evaluation over a batch of dataset.
:param return_loss: loss will be returned if True
:type return_loss: bool
:return: (output, label, loss)
"""
self.schedule.zero_grad(forward_only=self.forward_only)
output, label, loss = self.schedule.forward_backward_step(
forward_only=self.forward_only, return_loss=return_loss)
if not self.forward_only:
# all reduce gradients
self.handle_gradient()
self.schedule.step()
return output, label, loss
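A hedged sketch of how the engine is meant to be driven; it assumes the global context, logger and data loaders have already been set up by the library's initialization routine, and `train_dataloader`, `model`, `criterion` and `optimizer` are placeholders, not objects defined in this commit.

engine = Engine(train_dataloader=train_dataloader,
                model=model,
                criterion=criterion,
                optimizer=optimizer)           # no schedule given, so NoPipelineSchedule is used
engine.train()                                 # put the schedule and model into training mode
for _ in range(len(train_dataloader)):
    output, label, loss = engine.step()        # forward, backward, gradient handling, optimizer step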

10
colossalai/engine/amp_type.py

@ -0,0 +1,10 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from enum import Enum
class AMP_TYPE(Enum):
APEX = 'apex'
TORCH = 'torch'
PARALLEL = 'parallel'

5
colossalai/engine/gradient_handler/__init__.py

@ -0,0 +1,5 @@
from ._base_gradient_handler import BaseGradientHandler
from ._data_parallel_gradient_handler import DataParallelGradientHandler
from ._zero_gradient_handler import ZeROGradientHandler
__all__ = ['BaseGradientHandler', 'DataParallelGradientHandler', 'ZeROGradientHandler']

25
colossalai/engine/gradient_handler/_base_gradient_handler.py

@ -0,0 +1,25 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from abc import ABC, abstractmethod
class BaseGradientHandler(ABC):
"""A basic helper class to handle all-reduce operations of gradients across different parallel groups
before optimization.
:param model: Model where the gradients accumulate
:param optimizer: Optimizer for updating the parameters
:type model: Module
:type optimizer: Optimizer
"""
def __init__(self, model, optimizer):
self._model = model
self._optimizer = optimizer
@abstractmethod
def handle_gradient(self):
"""A method to accumulate gradients across different parallel groups. Users should
write their own functions or just use the functions in pre-defined subclasses.
"""
pass

48
colossalai/engine/gradient_handler/_data_parallel_gradient_handler.py

@ -0,0 +1,48 @@
#!/usr/bin/env python
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from colossalai.core import global_context as gpc
from colossalai.registry import GRADIENT_HANDLER
from ._base_gradient_handler import BaseGradientHandler
from ...context.parallel_mode import ParallelMode
@GRADIENT_HANDLER.register_module
class DataParallelGradientHandler(BaseGradientHandler):
"""A helper class to handle all-reduce operations in a data parallel group.
An all-reduce collective communication will be performed in
:func:`handle_gradient` among a data parallel group.
For better performance, it bucketizes the gradients of all parameters that are
the same type to improve the efficiency of communication.
"""
def handle_gradient(self):
"""A method running a all-reduce operation in a data parallel group.
"""
# TODO: add memory buffer
if gpc.data_parallel_size > 1:
# bucketize and all-reduce
buckets = {}
# Pack the buckets.
for param in self._model.parameters():
if param.requires_grad and param.grad is not None:
tp = param.data.type()
if tp not in buckets:
buckets[tp] = []
buckets[tp].append(param)
param.main_grad = param.grad
# For each bucket, all-reduce and copy all-reduced grads.
for tp in buckets:
bucket = buckets[tp]
grads = [param.grad.data for param in bucket]
coalesced = _flatten_dense_tensors(grads)
coalesced /= gpc.get_world_size(ParallelMode.DATA)
dist.all_reduce(
coalesced, group=gpc.get_group(ParallelMode.DATA))
for buf, synced in zip(grads, _unflatten_dense_tensors(
coalesced, grads)):
buf.copy_(synced)

16
colossalai/engine/gradient_handler/_zero_gradient_handler.py

@ -0,0 +1,16 @@
from colossalai.registry import GRADIENT_HANDLER
from ._base_gradient_handler import BaseGradientHandler
@GRADIENT_HANDLER.register_module
class ZeROGradientHandler(BaseGradientHandler):
"""A helper class to handle all-reduce operations in a data parallel group.
An all-reduce collective communication is performed in
:func:`handle_gradient` within the data parallel group.
This class is specialized with ZeRO optimization.
"""
def handle_gradient(self):
"""A method running a all-reduce operation in a data parallel group.
"""
self._optimizer.allreduce_gradients()

5
colossalai/engine/schedule/__init__.py

@ -0,0 +1,5 @@
from ._base_schedule import BaseSchedule
from ._no_pipeline import NoPipelineSchedule
from ._pipeline import PipelineSchedule
__all__ = ['BaseSchedule', 'NoPipelineSchedule', 'PipelineSchedule']

129
colossalai/engine/schedule/_base_schedule.py

@ -0,0 +1,129 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from abc import ABC, abstractmethod
import torch
from colossalai.logging import get_global_dist_logger
from colossalai.utils import get_current_device
class BaseSchedule(ABC):
"""A basic helper class to control the process of training or evaluation.
"""
def __init__(self):
self.initialized = False
self.logger = get_global_dist_logger()
@property
@abstractmethod
def num_steps(self):
"""The number of batches in training or evaluation.
"""
pass
def initialize(self,
dataloader=None,
model=None,
criterion=None,
optimizer=None,
lr_scheduler=None):
"""Initializes the schedule and set parameters before running.
:param dataloader: DataLoader in training or evaluation
:param model: The neural network model
:param criterion: Criterion for calculating loss
:param optimizer: Optimizer for updating the parameters
:param lr_scheduler: Learning rate scheduler in the process
"""
self.dataloader = dataloader
assert model is not None, "Schedule requires a model"
self.model = model
assert criterion is not None, "Schedule requires a criterion"
self.criterion = criterion
assert optimizer is not None, "Schedule requires an optimizer"
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.initialized = True
def check_initialized(self):
"""Checks whether the schedule is initialized.
"""
assert self.initialized, \
'Schedule is not initialized. Call schedule.initialize(...) before using it.'
def load_batch(self):
"""Loads a batch of dataset. It returns the data and labels which are
already in the same GPU as where the model's.
:return: (data, label)
:rtype: (Tensor, Tensor)
"""
self.check_initialized()
if self.data_iter is None:
raise RuntimeError('Dataloader is not defined.')
data, label = next(self.data_iter)
return self._move_to_device(data), self._move_to_device(label)
def _move_to_device(self, data):
if isinstance(data, (
tuple,
list,
)):
data = tuple([
d.to(get_current_device()).detach() for d in data
if torch.is_tensor(d)
])
elif torch.is_tensor(data):
data = data.to(get_current_device()).detach()
return data
def train(self, dataloader=None, mode=True):
"""Sets the dataloader to be used and turn the model to
training or evaluation mode.
:param dataloader: Dataloader to be used
:param mode: If True, the model will set as training mode. Otherwise, evaluation mode.
"""
self.check_initialized()
if mode:
self.model.train()
else:
self.model.eval()
if dataloader is not None:
self.dataloader = dataloader
self.data_iter = iter(dataloader)
def zero_grad(self, forward_only=False):
"""Cleans gradients with the optimizer.
"""
if not forward_only:
self.check_initialized()
self.optimizer.zero_grad()
def get_lr(self):
"""Returns the current learning rate.
"""
if self.lr_scheduler is not None:
return self.lr_scheduler.get_lr()[0]
else:
return self.optimizer.param_groups[0]['lr']
def step(self):
"""Updates the parameters and learning rate with the optimizer.
"""
self.check_initialized()
self.optimizer.step()
# update lr scheduler
if self.lr_scheduler is not None:
self.lr_scheduler.step()
@abstractmethod
def forward_backward_step(self, forward_only=False, return_loss=True):
"""The process function over a batch of dataset for training or evaluation.
:param forward_only: If True, the process won't include backward.
:param return_loss: If False, the loss won't be returned.
"""
pass
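The engine shown earlier in this commit drives a schedule through exactly these methods. A hedged sketch of that call sequence; dataloader, model, criterion and optimizer are assumed to be built elsewhere, and the concrete schedule class comes from the files below:

schedule = NoPipelineSchedule()                      # defined in _no_pipeline.py below
schedule.initialize(dataloader, model, criterion, optimizer)
schedule.train(dataloader=dataloader, mode=True)     # model.train() + fresh data iterator

for _ in range(schedule.num_steps):
    schedule.zero_grad()
    output, label, loss = schedule.forward_backward_step(return_loss=True)
    schedule.step()                                  # optimizer step + lr scheduler step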

185
colossalai/engine/schedule/_no_pipeline.py

@ -0,0 +1,185 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
try:
import apex.amp as apex_amp
except ImportError:
print('apex is required for mixed precision training')
try:
import torch.cuda.amp as torch_amp
except ImportError:
print('PyTorch amp is not supported with the current PyTorch version')
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.engine.amp_type import AMP_TYPE
from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)
from ._utils import convert_to_fp16
from ._base_schedule import BaseSchedule
class NoPipelineSchedule(BaseSchedule):
"""A helper schedule class for no pipeline parallelism running environment.
In each step, it loads a batch of data and feeds it to the model.
After getting the output and calculating the loss, it will use :meth:`step`
to update the parameters if it is in training mode.
:param amp_type: The type of automatic mixed precision
:param amp_config: The configuration of automatic mixed precision
:type amp_type: AMP_TYPE
:type amp_config: dict
"""
def __init__(
self,
amp_type: AMP_TYPE = None,
amp_config: dict = None,
):
super().__init__()
# mixed precision training
assert amp_type is None or isinstance(amp_type, AMP_TYPE), \
'unrecognised value for argument amp_type, it can only be None or a member of AMP_TYPE'
# LSG: check compatibility
# LSG: torch.cuda.amp and apex.amp cannot be used for tensor parallel
if gpc.is_initialized(ParallelMode.TENSOR) and gpc.get_world_size(
ParallelMode.TENSOR) > 1:
assert amp_type != AMP_TYPE.TORCH and amp_type != AMP_TYPE.APEX, \
'You can only use AMP_TYPE.PARALLEL for tensor parallel training'
self.use_zero_level_2_3 = False
if amp_type is not None:
self.fp16 = True
self.amp_type = amp_type
if amp_config is not None:
assert isinstance(amp_config, dict), \
f'expected argument amp_config to be a dict, but got {type(amp_config)}'
if self.amp_type == AMP_TYPE.TORCH:
# torch amp
if amp_config is None:
amp_config = dict()
self.amp_cfg = amp_config
elif self.amp_type == AMP_TYPE.APEX:
# apex amp
if amp_config is None:
amp_config = dict(opt_level='O2')
self.logger.warning(
'apex is deprecated, please consider using torch.cuda.amp instead.'
)
self.amp_cfg = amp_config
elif self.amp_type == AMP_TYPE.PARALLEL:
# use fp16 optimizer for tensor parallel training
if amp_config is None:
amp_config = dict()
self.amp_cfg = amp_config
else:
self.fp16 = False
self.amp_type = None
@property
def num_steps(self):
return len(self.dataloader)
def initialize(self,
dataloader,
model,
criterion,
optimizer,
lr_scheduler=None):
super().initialize(dataloader,
model,
criterion,
optimizer,
lr_scheduler=lr_scheduler)
if isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)):
self.use_zero_level_2_3 = True
assert self.amp_type != AMP_TYPE.PARALLEL, 'ZeRO Level 2 and 3 are mutually exclusive with AMP_TYPE.PARALLEL'
if self.fp16:
if self.amp_type == AMP_TYPE.TORCH:
self._torch_amp_scaler = torch_amp.GradScaler(**self.amp_cfg)
elif self.amp_type == AMP_TYPE.APEX:
self.model, self.optimizer = apex_amp.initialize(
self.model, self.optimizer, **self.amp_cfg)
def forward_backward_step(self, forward_only=False, return_loss=True):
"""The process function that loads loads a batch of dataset and feeds it to the model.
The returned labels and loss will None if :attr:`return_loss` is False.
:return: (output, label, loss)
"""
assert forward_only or return_loss, \
'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
data, label = self.load_batch()
loss = None
# LSG: leave for debug, make sure dataloader is deterministic
# if forward_only:
# img = data[0]
# rank = gpc.get_local_rank(ParallelMode.DATA)
# world_size = gpc.get_world_size(ParallelMode.DATA)
# group = gpc.get_group(ParallelMode.DATA)
# input_list = [img.clone() for _ in range(world_size)]
# output_list = [torch.empty_like(img) for _ in range(world_size)]
# output_list[rank] = img.clone()
# dist.all_to_all(output_tensor_list=output_list, input_tensor_list=input_list, group=group)
# assert torch.equal(output_list[0], output_list[1]) # and torch.equal(output_list[1], output_list[2])
# forward
if self.fp16 and self.amp_type == AMP_TYPE.TORCH:
with torch_amp.autocast():
output = self.model(*data)
if not isinstance(output, (tuple, list)):
output = (output,)
if return_loss:
loss = self.criterion(*output, *label)
else:
if self.use_zero_level_2_3 or self.amp_type == AMP_TYPE.PARALLEL:
data = convert_to_fp16(data)
output = self.model(*data)
if not isinstance(output, (tuple, list)):
output = (output,)
if return_loss:
loss = self.criterion(*output, *label)
if not forward_only:
# backward
if self.use_zero_level_2_3:
self.optimizer.backward(loss)
elif self.fp16:
if self.amp_type == AMP_TYPE.APEX:
with apex_amp.scale_loss(loss,
self.optimizer) as scaled_loss:
scaled_loss.backward()
elif self.amp_type == AMP_TYPE.TORCH:
self._torch_amp_scaler.scale(loss).backward()
elif self.amp_type == AMP_TYPE.PARALLEL:
loss = self.optimizer.scale_loss(loss)
loss.backward()
# scale back to display the original value in logs
loss.div_(self.optimizer.grad_scaler.scale)
else:
loss.backward()
if return_loss:
return output, label, loss
else:
return output, None, None
def step(self):
# step optimizer
if self.fp16 and self.amp_type == AMP_TYPE.TORCH:
self._torch_amp_scaler.step(self.optimizer)
self._torch_amp_scaler.update()
else:
self.optimizer.step()
# update lr scheduler
if self.lr_scheduler is not None:
self.lr_scheduler.step()
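A hedged construction example for the schedule above; the keyword values are illustrative and mirror what colossalai/initialize.py builds from the fp16 config section:

from colossalai.engine.amp_type import AMP_TYPE
from colossalai.engine.schedule import NoPipelineSchedule

# torch.cuda.amp mixed precision; amp_config is forwarded to GradScaler
schedule = NoPipelineSchedule(amp_type=AMP_TYPE.TORCH,
                              amp_config=dict(init_scale=2.**16))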

316
colossalai/engine/schedule/_pipeline.py

@ -0,0 +1,316 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Union
import torch.cuda
import torch.distributed as dist
from torch import Tensor
from colossalai.communication import *
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)
from colossalai.utils import get_current_device
from ._base_schedule import BaseSchedule
from ._utils import convert_to_fp16
from ..amp_type import AMP_TYPE
def squeeze(x: Union[Tensor, tuple, list]):
if isinstance(x, (tuple, list)):
return x[0]
else:
return x
class PipelineSchedule(BaseSchedule):
"""A helper schedule class for pipeline parallelism running environment.
It uses the non-interleaved 1F1B strategy. Other properties are similar to
:class:`NoPipelineSchedule`.
:param num_microbatches: The number of microbatches
:param amp_type: The type of automatic mixed precision
:param amp_config: The configuration of automatic mixed precision
:type num_microbatches: int
:type amp_type: AMP_TYPE
:type amp_config: dict
"""
def __init__(self,
num_microbatches,
amp_type: AMP_TYPE = None,
amp_config: dict = None):
super().__init__()
self.num_microbatches = num_microbatches
self.data_sync = True # turn off after making sure the data is identical
# amp
# LSGL: amp_config is not used, but leave here for future extension
self.amp_type = amp_type
self.amp_config = amp_config
if self.amp_type is not None:
assert self.amp_type == AMP_TYPE.PARALLEL, 'We only support AMP_TYPE.PARALLEL for pipeline training for now'
def _move_to_device(self, data):
if isinstance(data, (
tuple,
list,
)):
assert len(data) == 1, "Data tuple's length in pipeline should be 1"
data = data[0]
assert torch.is_tensor(data), "Data in pipeline should be tensor"
data = data.to(get_current_device()).detach()
return data
def _sync_data(self):
if gpc.is_first_rank(ParallelMode.PIPELINE):
src_rank = gpc.get_global_rank()
dist.broadcast(
tensor=self.batch_data,
src=src_rank,
group=gpc.get_group(ParallelMode.PIPELINE_PREV)
)
dist.broadcast(
tensor=self.batch_label,
src=src_rank,
group=gpc.get_group(ParallelMode.PIPELINE_PREV)
)
if gpc.is_last_rank(ParallelMode.PIPELINE):
src_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
dist.broadcast(
tensor=self.batch_data,
src=src_rank,
group=gpc.get_group(ParallelMode.PIPELINE_NEXT)
)
dist.broadcast(
tensor=self.batch_label,
src=src_rank,
group=gpc.get_group(ParallelMode.PIPELINE_NEXT)
)
# Pipeline schedule just puts data in memory
def load_batch(self):
self.check_initialized()
if self.data_iter is None:
raise RuntimeError('Dataloader is not defined.')
self.batch_pos = 0
data, label = next(self.data_iter)
self.batch_data, self.batch_label = \
self._move_to_device(data), self._move_to_device(label)
batch_size = self.batch_data.shape[0]
assert batch_size % self.num_microbatches == 0, \
"Batch size should divided by the number of microbatches"
self.microbatch_size = batch_size // self.num_microbatches
if self.data_sync:
self._sync_data()
def _get_data_slice(self, tensor):
return tensor[self.batch_pos: self.batch_pos + self.microbatch_size]
def load_micro_batch(self):
data = self._get_data_slice(self.batch_data)
label = self._get_data_slice(self.batch_label)
self.batch_pos += self.microbatch_size
return (data,), (label,)
@property
def num_steps(self):
return len(self.dataloader)
def initialize(self,
dataloader,
model,
criterion,
optimizer,
lr_scheduler=None):
super().initialize(dataloader,
model,
criterion,
optimizer,
lr_scheduler=lr_scheduler)
if isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)):
raise TypeError(
"Pipeline schedule is currently not compatible with ZeRO Level 2 and Level 3"
)
# LSG: set default dtype to fp16 for communication
if self.amp_type == AMP_TYPE.PARALLEL:
torch.set_default_dtype(torch.half)
self.logger.info(
'default tensor dtype is set to torch.half for fp16 training',
ranks=[0])
def forward_step(self, input_tensor, return_tensors, return_loss=True):
"""Forward step for passed-in model. If it is the first stage, the input tensor
is obtained from data_iterator, otherwise the passed-in input_tensor is used.
Returns output tensor. This is a helper function and can be ignored by users.
"""
if input_tensor is None:
input_tensor, label = self.load_micro_batch()
if self.amp_type == AMP_TYPE.PARALLEL:
input_tensor = convert_to_fp16(input_tensor)
input_tensor = squeeze(input_tensor)
output_tensor = self.model(input_tensor)
output_tensor = squeeze(output_tensor)
if gpc.is_last_rank(ParallelMode.PIPELINE):
if return_loss:
input_tensor, label = self.load_micro_batch()
loss_reduced = self.criterion(output_tensor, *
label) / self.num_microbatches
return_tensors.append(
tuple((output_tensor, label[0], loss_reduced)))
return loss_reduced
else:
return_tensors.append(output_tensor)
return output_tensor
else:
return output_tensor
def backward_step(self, input_tensor, output_tensor, output_tensor_grad):
"""Backward step through the passed-in output tensor. If it is the last stage, the
output_tensor_grad is None, otherwise it is the gradients with respect to stage's output tensor.
Returns the gradients with respect to the input tensor (None if first stage).
This is a helper function and can be ignored by users.
"""
# Retain the grad on the input_tensor.
if input_tensor is not None:
input_tensor.retain_grad()
# Backward pass.
if output_tensor_grad is None and self.amp_type == AMP_TYPE.PARALLEL:
output_tensor = self.optimizer.scale_loss(output_tensor)
torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
# Collect the grad of the input_tensor.
input_tensor_grad = None
if input_tensor is not None:
input_tensor_grad = input_tensor.grad
return input_tensor_grad
def forward_backward_step(self, forward_only=True, return_loss=True):
"""Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
Returns a tuple with losses if the last stage, an empty tuple otherwise.
:return: (output, label, loss)
"""
assert forward_only or return_loss, \
'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
self.load_batch()
num_warmup_microbatches = \
(gpc.get_world_size(ParallelMode.PIPELINE) -
gpc.get_local_rank(ParallelMode.PIPELINE) - 1)
num_warmup_microbatches = min(num_warmup_microbatches,
self.num_microbatches)
num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches
# Input, output tensors only need to be saved when doing backward passes
input_tensors = None
output_tensors = None
if not forward_only:
input_tensors = []
output_tensors = []
return_tensors = []
# Used for tensor meta information communication
ft_shape = None
bt_shape = None
fs_checker = True
# Run warmup forward passes.
for i in range(num_warmup_microbatches):
if not gpc.is_first_rank(ParallelMode.PIPELINE):
ft_shape = recv_tensor_meta(ft_shape)
input_tensor = recv_forward(ft_shape)
output_tensor = self.forward_step(input_tensor,
return_tensors,
return_loss=return_loss)
if not gpc.is_last_rank(ParallelMode.PIPELINE):
bt_shape = output_tensor.shape
fs_checker = send_tensor_meta(output_tensor, fs_checker)
send_forward(output_tensor)
if not forward_only:
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
# Before running 1F1B, need to receive first forward tensor.
# If all microbatches are run in warmup / cooldown phase, then no need to
# receive this tensor here.
if num_microbatches_remaining > 0:
if not gpc.is_first_rank(ParallelMode.PIPELINE):
ft_shape = recv_tensor_meta(ft_shape)
input_tensor = recv_forward(ft_shape)
# Run 1F1B in steady state.
for i in range(num_microbatches_remaining):
last_iteration = (i == (num_microbatches_remaining - 1))
output_tensor = self.forward_step(input_tensor,
return_tensors,
return_loss=return_loss)
if forward_only:
send_forward(output_tensor)
if not last_iteration:
input_tensor = recv_forward(ft_shape)
else:
output_tensor_grad = send_forward_recv_backward(
output_tensor, bt_shape)
# Add input_tensor and output_tensor to end of list.
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
# Pop input_tensor and output_tensor from the start of the list for
# the backward pass.
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
input_tensor_grad = self.backward_step(input_tensor,
output_tensor,
output_tensor_grad)
if last_iteration:
input_tensor = None
send_backward(input_tensor_grad)
else:
input_tensor = send_backward_recv_forward(
input_tensor_grad, ft_shape)
# Run cooldown backward passes.
if not forward_only:
for i in range(num_warmup_microbatches):
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
output_tensor_grad = recv_backward(bt_shape)
input_tensor_grad = self.backward_step(input_tensor,
output_tensor,
output_tensor_grad)
send_backward(input_tensor_grad)
if len(return_tensors) > 0:
if return_loss:
output, label, loss = tuple(map(list, zip(*return_tensors)))
return (torch.cat(output, dim=0),
torch.cat(label, dim=0),
sum(loss))
else:
return tuple((torch.cat(return_tensors, dim=0), None, None))
else:
return tuple((None, None, None))
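A hedged sizing example for the schedule above (all numbers are illustrative): with 4 pipeline stages and a global batch size of 256, num_microbatches=8 gives microbatches of 256 // 8 = 32 samples, and stage p (0-indexed) runs min(4 - p - 1, 8) warmup forward passes before entering the 1F1B steady state.

from colossalai.engine.schedule import PipelineSchedule

# batch size must be divisible by num_microbatches (checked in load_batch above)
schedule = PipelineSchedule(num_microbatches=8)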

16
colossalai/engine/schedule/_utils.py

@ -0,0 +1,16 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Union, List
from torch import Tensor
def convert_to_fp16(data: Union[Tensor, List[Tensor]]):
if isinstance(data, Tensor):
ret = data.half()
elif isinstance(data, (list, tuple)):
ret = [val.half() for val in data]
else:
raise TypeError(f"Expected argument 'data' to be a Tensor or a list/tuple of Tensor, but got {type(data)}")
return ret

371
colossalai/initialize.py

@ -0,0 +1,371 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import argparse
import pprint
import random
from pathlib import Path
from typing import Callable, Iterable, Optional, Union
import numpy as np
import torch
from torch.utils.data import DataLoader
from colossalai.engine import AMP_TYPE, NoPipelineSchedule, PipelineSchedule
from colossalai.logging import get_global_dist_logger, init_global_dist_logger
from colossalai.nn import DataParallelSampler
from colossalai.nn.model.base_model import BaseModel
from .builder import (ModelInitializer, build_dataset, build_loss,
build_lr_scheduler, build_model, build_optimizer,
build_optimizer_wrapper)
from .context import Config, ParallelMode
from .core import global_context as gpc
from .utils import get_current_device, sync_model_param_in_dp
def parse_args():
'''Reads user command line and uses an argument parser to parse the input arguments.
Input arguments include configuration, host, port, world size, local rank, backend for torch.distributed.
:return: the parsed arguments
:rtype: Namespace
'''
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, help='path to the config file')
parser.add_argument('--host',
type=str,
default=None,
help='the master address for distributed training')
parser.add_argument('--port',
type=str,
default=None,
help='the master port for distributed training')
parser.add_argument('--world_size', type=int, help='world size for distributed training')
parser.add_argument('--local_rank',
type=int,
help='rank for the default process group')
parser.add_argument('--backend',
type=str,
default='nccl',
help='backend for torch.distributed')
return parser.parse_args()
def init_dist(config: Union[str, dict] = None,
local_rank: int = None,
world_size: int = None,
host: str = None,
port: str = None,
backend: str = None):
'''This function first parses the configuration arguments, using :func:`parse_args` in case one of the input arguments is not given.
Then initialize and set distributed environment by calling global_context's functions.
:param config: config file or config file path are both acceptable
:type config: Union[str, dict], optional
:param local_rank: rank for the default process group, defaults to None
:type local_rank: int, optional
:param world_size: world size of GPUs, defaults to None
:type world_size: int, optional
:param host: the master address for distributed training, defaults to None
:type host: str, optional
:param port: the master port for distributed training, defaults to None
:type port: str, optional
:param backend: backend for torch.distributed, defaults to None
:type backend: str, optional
:raises Exception: raise exception when config type is wrong
'''
args = [config, local_rank, world_size, host, port, backend]
arg_given = [arg is not None for arg in args]
if not all(arg_given):
args = parse_args()
if config is None:
config = args.config
if local_rank is None:
local_rank = args.local_rank
if world_size is None:
world_size = args.world_size
if host is None:
host = args.host
if port is None:
port = args.port
if backend is None:
backend = args.backend
args = Config(
dict(config=config,
host=host,
port=port,
world_size=world_size,
local_rank=local_rank,
backend=backend))
# set distributed settings
dist_args = Config(
dict(local_rank=args.local_rank,
world_size=args.world_size,
backend=args.backend))
gpc.set_dist_args(dist_args)
# set config
if isinstance(args.config, dict):
cfg = args.config
elif isinstance(args.config, (str, Path)):
cfg = Config.from_file(args.config)
else:
raise Exception('Config type error: {}'.format(type(args.config)))
gpc.load_config(cfg)
# init dist groups
gpc.init_global_dist(args.host, args.port)
gpc.init_parallel_groups()
# init dist logger
init_global_dist_logger()
# set cuda device
if torch.cuda.is_available():
gpc.set_device()
def get_dataloader(dataset, seed=1024, add_sampler_if_possible=False, **kwargs):
'''Sets up a deterministic dataloader (also configures worker seeding, the sampler, and whether to shuffle)
.. note: when pipeline parallel is enabled, shuffle cannot be True
as it will result in mismatch between input data on the 1st
stage and label on the last stage
:param dataset: a :class:`torch.utils.data.Dataset` instance
:param seed: random worker seed, defaults to 1024
:type seed: int, optional
:param add_sampler_if_possible: add a :class:`DataParallelSampler` when data parallelism is enabled, defaults to False
:type add_sampler_if_possible: bool, optional
:return: a dataloader built from the given dataset
:rtype: torch.utils.data.DataLoader
'''
_kwargs = kwargs.copy()
if 'shuffle' in _kwargs:
shuffle = _kwargs.pop('shuffle')
else:
shuffle = False
if add_sampler_if_possible and gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1:
sampler = DataParallelSampler(dataset, shuffle=shuffle)
else:
sampler = None
# Deterministic dataloader
def seed_worker(worker_id):
worker_seed = seed
np.random.seed(worker_seed)
torch.manual_seed(worker_seed)
random.seed(worker_seed)
if sampler is None:
return DataLoader(dataset,
worker_init_fn=seed_worker,
shuffle=shuffle,
**_kwargs)
else:
return DataLoader(dataset,
sampler=sampler,
worker_init_fn=seed_worker,
**_kwargs)
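A hedged usage example of get_dataloader; the dataset variable and the DataLoader keyword values are illustrative. Extra keyword arguments are forwarded to torch.utils.data.DataLoader, and shuffle is handed to the DataParallelSampler when one is attached:

train_loader = get_dataloader(train_dataset,            # any torch Dataset, e.g. CIFAR10Dataset from this commit
                              seed=1024,
                              add_sampler_if_possible=True,
                              batch_size=256,
                              shuffle=True,
                              num_workers=4,
                              pin_memory=True)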
def initialize(config: Union[str, dict] = None,
local_rank: int = None,
world_size: int = None,
host: str = None,
port: str = None,
backend: str = None,
train_dataloader: Optional[Union[Iterable, Callable]] = None,
test_dataloader: Optional[Union[Iterable, Callable]] = None,
):
'''Core function that initializes the distributed environment, logger, cudnn, data, model, loss function, optimizer, and lr_scheduler (their configs are in gpc.config).
:param config: config file or config file path are both acceptable
:type config: Union[str, dict], optional
:param local_rank: rank for the default process group, defaults to None
:type local_rank: int, optional
:param world_size: world size of GPUs, defaults to None
:type world_size: int, optional
:param host: the master address for distributed training, defaults to None
:type host: str, optional
:param port: the master port for distributed training, defaults to None
:type port: str, optional
:param backend: backend for torch.distributed, defaults to None
:type backend: str, optional
:param train_dataloader: If None, the config is used to build a dataloader; otherwise, it should be a dataloader object or a function with no arguments which builds a dataloader, defaults to None
:type train_dataloader: Optional[Union[Iterable, Callable]], optional
:param test_dataloader: If None, the config is used to build a dataloader; otherwise, it should be a dataloader object or a function with no arguments which builds a dataloader, defaults to None
:type test_dataloader: Optional[Union[Iterable, Callable]], optional
:return: (model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler)
:rtype: tuple
'''
# initialize distributed environment
init_dist(config=config,
local_rank=local_rank,
world_size=world_size,
host=host,
port=port,
backend=backend)
# init logger
logger = get_global_dist_logger()
logger.info(f'Distributed environment is initialized, '
f'data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, '
f'tensor parallel size: {gpc.tensor_parallel_size}', ranks=[0])
# print config
logger.info(f"\n========== Your Config ========\n"
f"{pprint.pformat(gpc.config)}\n"
f"================================", ranks=[0])
# cudnn
cudnn_benchmark = gpc.config.get('cudnn_benchmark', True)
cudnn_deterministic = gpc.config.get('cudnn_deterministic', False)
torch.backends.cudnn.benchmark = cudnn_benchmark
torch.backends.cudnn.deterministic = cudnn_deterministic
logger.info(
f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0])
# set seed, cuda seed is only set when cuda is avail
gpc.set_seed()
# return_items = list()
# check fp16 and zero
should_convert_model_to_half = False
should_wrap_fp16_optimizer = False
should_wrap_zero_optimizer_level_2_3 = False
if hasattr(gpc.config, 'fp16'):
fp16_mode = gpc.config.fp16.mode
if fp16_mode == AMP_TYPE.PARALLEL:
should_convert_model_to_half = True
should_wrap_fp16_optimizer = True
if hasattr(gpc.config, 'zero'):
should_wrap_zero_optimizer_level_2_3 = True
zero_type = gpc.config.zero.type
if zero_type in ['ZeroRedundancyOptimizer_Level_2', 'ZeroRedundancyOptimizer_Level_3']:
should_convert_model_to_half = True
assert not should_wrap_fp16_optimizer, \
'AMP_TYPE.PARALLEL is mutually exclusive with zero level 2 and 3'
# build model
logger.info('Building model ...', ranks=[0])
assert hasattr(
gpc.config, 'model'), "Build error: configuration 'model' is missing"
if gpc.pipeline_parallel_size > 1:
model = ModelInitializer(gpc.config.model, 1, verbose=True)
model = model.model_initialize()
else:
model = build_model(gpc.config.model)
if isinstance(model, BaseModel):
model.build_from_cfg()
model = model.to(get_current_device())
sync_model_param_in_dp(model)
logger.info('Model is created', ranks=[0])
if should_convert_model_to_half:
model = model.half()
logger.info("Model is cast to fp16", ranks=[0])
# training data
if callable(train_dataloader):
logger.info(
f'Build train data loader from {train_dataloader}', ranks=[0])
train_dataloader = train_dataloader()
if train_dataloader is None and hasattr(gpc.config, 'train_data'):
logger.info('Preparing data ...', ranks=[0])
# assert hasattr(gpc.config, 'train_data'), "Build error: configuration 'train_data' is missing."
train_dataset = build_dataset(gpc.config.train_data.dataset)
logger.info('Train dataset is ready.', ranks=[0])
train_dataloader = get_dataloader(train_dataset,
gpc.config.get('seed', 1024),
True,
**gpc.config.train_data.dataloader,
)
logger.info(
f'Loaded {len(train_dataset)} samples in {len(train_dataloader)} batches for training', ranks=[0])
if callable(test_dataloader):
logger.info(
f'Build test data loader from {test_dataloader}', ranks=[0])
test_dataloader = test_dataloader()
# testing data, allowed to be None
if test_dataloader is None and hasattr(gpc.config, 'test_data'):
test_dataset = build_dataset(gpc.config.test_data.dataset)
test_dataloader = get_dataloader(
test_dataset, add_sampler_if_possible=True, **gpc.config.test_data.dataloader)
logger.info(
f'Loaded {len(test_dataset)} samples in {len(test_dataloader)} batches for testing', ranks=[0])
# build loss function
assert hasattr(gpc.config, 'loss'), \
'Build error: configuration \'loss\' is missing.'
criterion = build_loss(gpc.config.loss)
logger.info('Loss function is created', ranks=[0])
# build optimizer
assert hasattr(gpc.config, 'optimizer'), \
"Build error: configuration 'optimizer' is missing."
optim_type = gpc.config.optimizer.type
is_pytorch_native_zero_level_1 = optim_type == 'ZeroRedundancyOptimizer'
if is_pytorch_native_zero_level_1:
original_cfg_copy = gpc.config.optimizer.copy()
original_cfg_copy.pop('type')
cfg = dict(type=optim_type, process_group=gpc.get_group(
ParallelMode.DATA), **original_cfg_copy)
optimizer = build_optimizer(cfg, model)
else:
optimizer = build_optimizer(gpc.config.optimizer, model)
if should_wrap_zero_optimizer_level_2_3:
optimizer = build_optimizer_wrapper(gpc.config.zero, optimizer, model)
if should_wrap_fp16_optimizer:
# replace the field mode with type
fp16_cfg = gpc.config.fp16.copy()
amp_type = fp16_cfg.pop('mode')
assert amp_type == AMP_TYPE.PARALLEL, 'FP Optimizer should only be used for AMP_TYPE.PARALLEL'
fp16_cfg['type'] = 'FP16Optimizer'
optimizer = build_optimizer_wrapper(fp16_cfg, optimizer)
logger.info('Optimizer is created', ranks=[0])
lr_scheduler = None
if hasattr(gpc.config, 'lr_scheduler'):
if hasattr(gpc.config, 'num_steps'):
total_steps = gpc.config.num_steps
elif hasattr(gpc.config, 'num_epochs'):
total_steps = int(gpc.config.num_epochs * len(train_dataloader))
else:
raise Exception(
'Please specify training stopping criterion num_steps or num_epochs in your configuration.'
)
lr_scheduler = build_lr_scheduler(gpc.config.lr_scheduler, optimizer,
total_steps, len(train_dataloader))
logger.info('Learning rate scheduler is created', ranks=[0])
# pipeline or no pipeline schedule
if hasattr(gpc.config, 'fp16'):
amp_type = gpc.config.fp16.mode
amp_cfg = gpc.config.fp16.copy()
amp_cfg.pop('mode')
else:
amp_type = None
amp_cfg = None
if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
assert hasattr(gpc.config,
'schedule'), "Config 'schedule' not found in your configuration file for pipeline parallel training"
schedule = PipelineSchedule(
amp_type=amp_type, amp_config=amp_cfg, **gpc.config.schedule.copy())
else:
schedule = NoPipelineSchedule(amp_type=amp_type, amp_config=amp_cfg)
return model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler
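A hedged end-to-end usage sketch of initialize; the config path is hypothetical and the launch arguments (local rank, world size, host, port) are assumed to be supplied on the command line and picked up by parse_args:

from colossalai.initialize import initialize

model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = \
    initialize(config='./configs/my_config.py')   # hypothetical config file

# a single training step afterwards mirrors Engine.step() shown earlier
schedule.zero_grad()
output, label, loss = schedule.forward_backward_step(return_loss=True)
schedule.step()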

26
colossalai/logging/__init__.py

@ -0,0 +1,26 @@
from colossalai.core import global_context as gpc
from .logging import DistributedLogger
__all__ = ['get_global_dist_logger', 'get_dist_logger', 'DistributedLogger', 'init_global_dist_logger']
_GLOBAL_LOGGER: DistributedLogger = None
def get_dist_logger(name, level='INFO', root_path: str = None, mode='a'):
return DistributedLogger(name=name, level=level, root_path=root_path, mode=mode)
def get_global_dist_logger():
assert _GLOBAL_LOGGER is not None, 'Global distributed logger is not initialized'
return _GLOBAL_LOGGER
def init_global_dist_logger():
rank = gpc.get_global_rank()
if hasattr(gpc.config, 'logging'):
logger = get_dist_logger(name=f'rank_{rank}', **gpc.config.logging)
else:
logger = get_dist_logger(name=f'rank_{rank}', level='INFO')
global _GLOBAL_LOGGER
assert _GLOBAL_LOGGER is None, 'Global distributed logger has already been initialized'
_GLOBAL_LOGGER = logger
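init_dist in colossalai/initialize.py (shown above) calls init_global_dist_logger() once during start-up, after which any module can fetch the shared logger. A hedged usage sketch:

from colossalai.logging import get_global_dist_logger

logger = get_global_dist_logger()
logger.info('training started', ranks=[0])   # emitted only by rank 0
logger.warning('loss became NaN')            # emitted by every rank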

97
colossalai/logging/logging.py

@ -0,0 +1,97 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import logging
from pathlib import Path
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
_FORMAT = 'colossalai - %(name)s - %(asctime)s %(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=_FORMAT)
class DistributedLogger:
"""This is a distributed event logger class essentially based on :class:`logging`.
:param name: The name of the logger
:type name: str
:param level: The threshold for the logger. Logging messages which are less severe than `level`
will be ignored
:type level: str
:param root_path: The root path where logs are stored
:type root_path: str, optional
:param mode: The mode that the file is opened in. Defaults to 'a'
:type mode: str, optional
"""
def __init__(self, name, level='INFO', root_path: str = None, mode='a'):
self._logger = logging.getLogger(name)
self._logger.setLevel(getattr(logging, level))
if root_path is not None:
log_root_path = Path(root_path)
# create path if not exists
log_root_path.mkdir(parents=True, exist_ok=True)
log_path = log_root_path.joinpath(f'{name}.log')
file_handler = logging.FileHandler(log_path, mode)
file_handler.setLevel(getattr(logging, level))
formatter = logging.Formatter(_FORMAT)
file_handler.setFormatter(formatter)
self._logger.addHandler(file_handler)
def _log(self, level, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
if ranks is None:
getattr(self._logger, level)(message)
else:
local_rank = gpc.get_local_rank(parallel_mode)
if local_rank in ranks:
getattr(self._logger, level)(message)
def info(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
"""Stores an info log message.
:param message: The message to be logged
:type message: str
:param parallel_mode: The parallel mode used for logging. Defaults to ParallelMode.GLOBAL
:type parallel_mode: :class:`colossalai.context.parallel_mode.ParallelMode`
:param ranks: List of parallel ranks
:type ranks: list
"""
self._log('info', message, parallel_mode, ranks)
def warning(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
"""Stores a warning log message.
:param message: The message to be logged
:type message: str
:param parallel_mode: The parallel mode used for logging. Defaults to ParallelMode.GLOBAL
:type parallel_mode: :class:`colossalai.context.parallel_mode.ParallelMode`
:param ranks: List of parallel ranks
:type ranks: list
"""
self._log('warning', message, parallel_mode, ranks)
def debug(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
"""Stores a debug log message.
:param message: The message to be logged
:type message: str
:param parallel_mode: The parallel mode used for logging. Defaults to ParallelMode.GLOBAL
:type parallel_mode: :class:`colossalai.context.parallel_mode.ParallelMode`
:param ranks: List of parallel ranks
:type ranks: list
"""
self._log('debug', message, parallel_mode, ranks)
def error(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
"""Stores an error log message.
:param message: The message to be logged
:type message: str
:param parallel_mode: The parallel mode used for logging. Defaults to ParallelMode.GLOBAL
:type parallel_mode: :class:`colossalai.context.parallel_mode.ParallelMode`
:param ranks: List of parallel ranks
:type ranks: list
"""
self._log('error', message, parallel_mode, ranks)

6
colossalai/nn/__init__.py

@ -0,0 +1,6 @@
from .data import *
from .layer import *
from .loss import *
from .lr_scheduler import *
from .model import *
from .optimizer import *

3
colossalai/nn/data/__init__.py

@ -0,0 +1,3 @@
from .caltech101_dataset import Caltech101Dataset
from .cifar10_dataset import CIFAR10Dataset
from .sampler import *

14
colossalai/nn/data/_utils.py

@ -0,0 +1,14 @@
import numpy as np
def pil_img_to_numpy(pil_img):
"""convert a PIL image to numpy nd-array
:param pil_img: a PIL image
:type pil_img: PIL.Image
:return: an ndarray
:rtype: numpy.ndarray
"""
np_img = np.array(pil_img)
np_img = np.rollaxis(np_img, 2) # HWC to CHW
return np_img

17
colossalai/nn/data/base_dataset.py

@ -0,0 +1,17 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from abc import ABC
from torch.utils.data import Dataset
from torchvision.transforms import transforms
from colossalai.builder import build_transform
class BaseDataset(Dataset, ABC):
def __init__(self, transform_pipeline: list):
transform_list = [build_transform(cfg) for cfg in transform_pipeline]
transform = transforms.Compose(transform_list)
self._transform_pipeline = transform

43
colossalai/nn/data/caltech101_dataset.py

@ -0,0 +1,43 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.distributed as dist
from torchvision.datasets import Caltech101
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import DATASETS
from .base_dataset import BaseDataset
@DATASETS.register_module
class Caltech101Dataset(BaseDataset):
"""`Caltech 101 <http://www.vision.caltech.edu/Image_Datasets/Caltech101/>`_ Dataset.
:param transform_pipeline: A list of transform configs; the composed pipeline takes in a PIL image
and returns a transformed version
:type transform_pipeline: list
"""
def __init__(self, transform_pipeline: list, *args, **kwargs):
super().__init__(transform_pipeline)
if gpc.is_initialized(ParallelMode.GLOBAL) and gpc.get_global_rank() != 0:
dist.barrier()
self._dataset = Caltech101(
transform=self._transform_pipeline, *args, **kwargs)
if gpc.is_initialized(ParallelMode.GLOBAL) and gpc.get_global_rank() == 0:
dist.barrier()
def __len__(self):
return len(self._dataset)
def __getitem__(self, item):
"""
:param item: Index
:type item: int
:return: ((image,), (target,)) where the type of target is specified by target_type.
:rtype: tuple
"""
img, label = self._dataset.__getitem__(item)
return (img,), (label,)

44
colossalai/nn/data/cifar10_dataset.py

@ -0,0 +1,44 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.distributed as dist
from torchvision.datasets import CIFAR10
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import DATASETS
from .base_dataset import BaseDataset
@DATASETS.register_module
class CIFAR10Dataset(BaseDataset):
"""`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
:param transform_pipeline: A list of transform configs; the composed pipeline takes in a PIL image
and returns a transformed version
:type transform_pipeline: list
"""
def __init__(self, transform_pipeline: list, *args, **kwargs):
super().__init__(transform_pipeline)
if gpc.is_initialized(ParallelMode.GLOBAL) and gpc.get_global_rank() != 0:
dist.barrier()
self._dataset = CIFAR10(transform=self._transform_pipeline,
*args,
**kwargs)
if gpc.is_initialized(ParallelMode.GLOBAL) and gpc.get_global_rank() == 0:
dist.barrier()
def __len__(self):
return len(self._dataset)
def __getitem__(self, item):
"""
:param item: Index
:type item: int
:return: ((image,), (target,)) where the type of target is specified by target_type.
:rtype: tuple
"""
img, label = self._dataset.__getitem__(item)
return (img,), (label,)

4
colossalai/nn/data/sampler/__init__.py

@ -0,0 +1,4 @@
from .base_sampler import BaseSampler
from .data_parallel_sampler import DataParallelSampler
__all__ = ['BaseSampler', 'DataParallelSampler']

19
colossalai/nn/data/sampler/base_sampler.py

@ -0,0 +1,19 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from abc import ABC, abstractmethod
class BaseSampler(ABC):
def __init__(self, dataset, batch_size):
self.dataset = dataset
self.batch_size = batch_size
@abstractmethod
def __len__(self):
pass
@abstractmethod
def __iter__(self):
pass

102
colossalai/nn/data/sampler/data_parallel_sampler.py

@ -0,0 +1,102 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# adapted from torch.utils.data.DistributedSampler
import math
from typing import TypeVar, Iterator
import torch
from torch.utils.data import Sampler, Dataset
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import SAMPLERS
T_co = TypeVar('T_co', covariant=True)
@SAMPLERS.register_module
class DataParallelSampler(Sampler):
"""A data sampler for distributed data parallelism
:param dataset: a Dataset instance
:type dataset: torch.utils.data.Dataset
:param shuffle: whether to shuffle data, defaults to False
:type shuffle: bool, optional
:param seed: the random seed, defaults to 0
:type seed: int, optional
:param drop_last: set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller, defaults to False
:type drop_last: bool, optional
"""
def __init__(self,
dataset: Dataset,
shuffle: bool = False,
seed: int = 0,
drop_last: bool = False) -> None:
self.dataset = dataset
self.num_replicas = gpc.get_world_size(ParallelMode.DATA)
self.rank = gpc.get_local_rank(ParallelMode.DATA)
self.epoch = 0
self.drop_last = drop_last
# If the dataset length is evenly divisible by # of replicas, then there
# is no need to drop any data, since the dataset will be split equally.
# type: ignore[arg-type]
if self.drop_last and len(self.dataset) % self.num_replicas != 0:
# Split to nearest available length that is evenly divisible.
# This is to ensure each rank receives the same amount of data when
# using this Sampler.
self.num_samples = math.ceil(
# `type:ignore` is required because Dataset cannot provide a default __len__
# see NOTE in pytorch/torch/utils/data/sampler.py
(len(self.dataset) - self.num_replicas) / \
self.num_replicas # type: ignore[arg-type]
)
else:
self.num_samples = math.ceil(
len(self.dataset) / self.num_replicas) # type: ignore[arg-type]
self.total_size = self.num_samples * self.num_replicas
self.shuffle = shuffle
self.seed = seed
def __iter__(self) -> Iterator[T_co]:
if self.shuffle:
# deterministically shuffle based on epoch and seed
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
# type: ignore[arg-type]
indices = torch.randperm(len(self.dataset), generator=g).tolist()
else:
indices = list(range(len(self.dataset))) # type: ignore[arg-type]
if not self.drop_last:
# add extra samples to make it evenly divisible
padding_size = self.total_size - len(indices)
if padding_size <= len(indices):
indices += indices[:padding_size]
else:
indices += (indices * math.ceil(padding_size /
len(indices)))[:padding_size]
else:
# remove tail of data to make it evenly divisible.
indices = indices[:self.total_size]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices)
def __len__(self) -> int:
return self.num_samples
def set_epoch(self, epoch: int) -> None:
r"""Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
use a different random ordering for each epoch. Otherwise, the next iteration of this
sampler will yield the same ordering.
:param epoch: Epoch number.
:type epoch: int
"""
self.epoch = epoch
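As with torch.utils.data.DistributedSampler, set_epoch must be called at the start of every epoch when shuffle=True so that each epoch uses a different permutation. A hedged sketch; train_dataset, the batch size and num_epochs are illustrative:

from torch.utils.data import DataLoader
from colossalai.nn import DataParallelSampler

sampler = DataParallelSampler(train_dataset, shuffle=True)
loader = DataLoader(train_dataset, sampler=sampler, batch_size=128)

for epoch in range(num_epochs):
    sampler.set_epoch(epoch)          # re-seeds the per-epoch permutation
    for data, label in loader:
        ...                           # training step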

9
colossalai/nn/layer/__init__.py

@ -0,0 +1,9 @@
from .parallel_1d import *
from .parallel_2d import *
from .parallel_2p5d import *
from .parallel_3d import *
from .parallel_sequence import *
from .parallel_vision_transformer import *
from .vanilla_resnet import *
from .vanilla_vision_transformer import *
from .wrapper import *

63
colossalai/nn/layer/_common_utils.py

@ -0,0 +1,63 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math
import torch
from torch import Tensor
from torch import nn
from colossalai.utils import checkpoint
from colossalai.constants import IS_TENSOR_PARALLEL
def divide(numerator, denominator):
""" only allow exact division """
assert numerator % denominator == 0, \
'{} is not divisible by {}'.format(numerator, denominator)
return numerator // denominator
def gelu(x: Tensor) -> Tensor:
"""Implementation of the gelu activation function.
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
"""
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
def swish(x: Tensor) -> Tensor:
return x * torch.sigmoid(x)
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
def set_tensor_parallel_attribute(param):
if not hasattr(param, IS_TENSOR_PARALLEL):
setattr(param, IS_TENSOR_PARALLEL, True)
class CheckpointModule(nn.Module):
def __init__(self, checkpoint: bool = True):
super().__init__()
self.checkpoint = checkpoint
self._use_checkpoint = checkpoint
def _forward(self, *args):
raise NotImplementedError(
'CheckpointModule should implement the _forward method instead of the original forward')
def forward(self, *args):
if self._use_checkpoint:
return checkpoint(self._forward, *args)
else:
return self._forward(*args)
def train(self, mode: bool = True):
self._use_checkpoint = self.checkpoint
return super().train(mode=mode)
def eval(self):
self._use_checkpoint = False
return super().eval()
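A hedged example of subclassing CheckpointModule above; the block is illustrative (names and sizes are made up) and only shows that subclasses implement _forward while the inherited forward decides whether to run it under activation checkpointing:

import torch
import torch.nn as nn

class MLPBlock(CheckpointModule):
    def __init__(self, dim: int, hidden_dim: int, checkpoint: bool = True):
        super().__init__(checkpoint)
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, dim)

    def _forward(self, x):
        # recomputed during backward when checkpointing is enabled
        return self.fc2(torch.relu(self.fc1(x)))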

138
colossalai/nn/layer/_parallel_utilities.py

@ -0,0 +1,138 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
import torch.distributed as dist
from colossalai.core import global_context as gpc
def _reduce(input_, parallel_mode):
# skip if only one rank involved
if gpc.get_world_size(parallel_mode) == 1:
return input_
dist.all_reduce(input_, group=gpc.get_group(parallel_mode))
return input_
def _split(input_, parallel_mode, dim=-1):
# skip if only one rank involved
world_size = gpc.get_world_size(parallel_mode)
if world_size == 1:
return input_
# Split along last dimension.
dim_size = input_.size(dim)
assert dim_size % world_size == 0, \
f'The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), ' \
f'cannot split tensor evenly'
tensor_list = torch.split(input_, dim_size // world_size, dim=dim)
rank = gpc.get_local_rank(parallel_mode)
output = tensor_list[rank].contiguous()
return output
def _gather(input_, parallel_mode, dim=-1):
# skip if only one rank involved
world_size = gpc.get_world_size(parallel_mode)
if world_size == 1:
return input_
# all gather
rank = gpc.get_local_rank(parallel_mode)
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
tensor_list[rank] = input_
torch.distributed.all_gather(tensor_list, input_, group=gpc.get_group(parallel_mode))
# concat
output = torch.cat(tensor_list, dim=dim).contiguous()
return output
class _ReduceGrad(torch.autograd.Function):
"""Pass the input to the model parallel region."""
@staticmethod
def symbolic(graph, input_):
return input_
@staticmethod
def forward(ctx, input_, parallel_mode):
ctx.mode = parallel_mode
return input_
@staticmethod
def backward(ctx, grad_output):
return _reduce(grad_output, ctx.mode), None
class _ReduceInput(torch.autograd.Function):
"""All-reduce the input from the model parallel region."""
@staticmethod
def symbolic(graph, input_):
return _reduce(input_)
@staticmethod
def forward(ctx, input_, parallel_mode):
return _reduce(input_, parallel_mode)
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
class _SplitForwardGatherBackward(torch.autograd.Function):
"""Split the input and keep only the corresponding chuck to the rank."""
@staticmethod
def symbolic(graph, input_):
return _split(input_)
@staticmethod
def forward(ctx, input_, parallel_mode, dim):
ctx.mode = parallel_mode
ctx.dim = dim
return _split(input_, parallel_mode, dim)
@staticmethod
def backward(ctx, grad_output):
return _gather(grad_output, ctx.mode, ctx.dim), None, None
class _GatherForwardSplitBackward(torch.autograd.Function):
"""Gather the input from model parallel region and concatinate."""
@staticmethod
def symbolic(graph, input_):
return _gather(input_)
@staticmethod
def forward(ctx, input_, parallel_mode, dim):
ctx.mode = parallel_mode
ctx.dim = dim
return _gather(input_, parallel_mode, dim)
@staticmethod
def backward(ctx, grad_output):
return _split(grad_output, ctx.mode, ctx.dim), None, None
def reduce_grad(input_, parallel_mode):
return _ReduceGrad.apply(input_, parallel_mode)
def reduce_input(input_, parallel_mode):
return _ReduceInput.apply(input_, parallel_mode)
def split_forward_gather_backward(input_, parallel_mode, dim):
return _SplitForwardGatherBackward.apply(input_, parallel_mode, dim)
def gather_forward_split_backward(input_, parallel_mode, dim):
return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim)
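Each wrapper above pairs a forward collective with the matching backward collective. A hedged sketch of the pairing, as used by the 1D layers in the next files (the tensor names are illustrative):

from colossalai.context.parallel_mode import ParallelMode

# identity in forward, all-reduce of the gradient in backward
x = reduce_grad(x, ParallelMode.PARALLEL_1D)

# gather shards along the last dim in forward, split the gradient in backward
full = gather_forward_split_backward(partial, ParallelMode.PARALLEL_1D, dim=-1)

# the reverse pairing: split in forward, gather the gradient in backward
shard = split_forward_gather_backward(full, ParallelMode.PARALLEL_1D, dim=-1)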

27
colossalai/nn/layer/base_layer.py

@ -0,0 +1,27 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.nn as nn
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
class ParallelLayer(nn.Module):
def __init__(self):
super().__init__()
self.data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(
ParallelMode.DATA)
self.data_parallel_size = 1 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_world_size(
ParallelMode.DATA)
self.tensor_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.TENSOR) else gpc.get_local_rank(
ParallelMode.TENSOR)
self.tensor_parallel_size = 1 if not gpc.is_initialized(ParallelMode.TENSOR) else gpc.get_world_size(
ParallelMode.TENSOR)
self.pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(
ParallelMode.PIPELINE)
self.pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(
ParallelMode.PIPELINE)

5
colossalai/nn/layer/parallel_1d/__init__.py

@ -0,0 +1,5 @@
from .layers import Linear1D_Col, Linear1D_Row
__all__ = [
'Linear1D_Col', 'Linear1D_Row',
]

15
colossalai/nn/layer/parallel_1d/_utils.py

@ -0,0 +1,15 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from .._common_utils import divide
def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank):
index_f = rank * per_partition_vocab_size
index_l = index_f + per_partition_vocab_size
return index_f, index_l
def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
per_partition_vocab_size = divide(global_vocab_size, world_size)
return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank)

166
colossalai/nn/layer/parallel_1d/layers.py

@ -0,0 +1,166 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch import Tensor
from torch.nn.parameter import Parameter
from typing import Tuple
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import LAYERS
from colossalai.utils import get_current_device
from .._common_utils import divide
from .._parallel_utilities import reduce_grad, reduce_input, gather_forward_split_backward, \
split_forward_gather_backward
from ..base_layer import ParallelLayer
class Linear1D_Col(ParallelLayer):
"""Linear layer with column parallelism.
The linear layer is defined as :math:`Y = XA + b`. A is parallelized along
its second dimension as :math:`A = [A_1, ..., A_p]`.
:param in_features: first dimension of matrix A.
:type in_features: int
:param output_size: second dimension of matrix A.
:type output_size: int
:param bias: If true, add bias, defaults to True
:type bias: bool, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param gather_output: If true, call all-gather on the output and make Y available
to all GPUs; otherwise, every GPU will have its own output
which is :math:`Y_i = XA_i`, defaults to False
:type gather_output: bool, optional
"""
def __init__(self,
in_features: int,
output_size: int,
bias: bool = True,
dtype: torch.dtype = None,
gather_output: bool = False):
super().__init__()
# Keep input parameters
self.input_size = in_features
self.output_size = output_size
self.gather_output = gather_output
self.skip_bias_add = not bias
world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
self.output_size_per_partition = divide(output_size, world_size)
# Parameters.
# Initialize weight.
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
self.weight = Parameter(torch.empty(
self.output_size_per_partition, self.input_size,
**factory_kwargs))
if bias:
self.bias = Parameter(torch.empty(
self.output_size_per_partition,
**factory_kwargs))
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
else:
self.register_parameter('bias', None)
def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
# Set up backprop all-reduce.
input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
output_parallel = F.linear(input_parallel, self.weight, bias)
if self.gather_output:
# All-gather across the partitions.
output = gather_forward_split_backward(
output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
else:
output = output_parallel
if self.skip_bias_add:
return output, self.bias
else:
return output
@LAYERS.register_module
class Linear1D_Row(ParallelLayer):
""" Linear layer with row parallelism
:param in_features: size of each input sample
:type in_features: int
:param out_features: size of each output sample
:type out_features: int
:param bias: If set to ``False``, the layer will not learn an additive bias, defaults to True
:type bias: bool, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param parallel_input: If set to ``True``, the input is assumed to be already split across ranks; otherwise it is split inside the layer, defaults to False
:type parallel_input: bool, optional
"""
def __init__(self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
parallel_input: bool = False
):
super().__init__()
# Keep input parameters
self.in_features = in_features
self.out_features = out_features
self.parallel_input = parallel_input
self.skip_bias_add = not bias
# Divide the weight matrix along the last dimension.
world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
self.input_size_per_partition = divide(in_features, world_size)
# Parameters.
# Initialize weight.
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
self.weight = Parameter(torch.empty(
self.out_features,
self.input_size_per_partition,
**factory_kwargs))
if bias:
self.bias = Parameter(torch.empty(
self.out_features,
**factory_kwargs
))
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
else:
self.register_parameter('bias', None)
def reset_parameters(self) -> None:
init.xavier_normal_(self.weight)
def forward(self, input_: Tensor) -> Tensor:
# Set up backprop all-reduce.
if self.parallel_input:
input_ = input_
else:
input_ = split_forward_gather_backward(
input_, ParallelMode.PARALLEL_1D, dim=-1)
output_parallel = F.linear(input_, self.weight)
output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)
if not self.skip_bias_add:
output = output + self.bias
return output
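# Hedged usage sketch (not part of the diff): the column/row pair above mirrors the
# Megatron-style sharded MLP. This single-process simulation replaces the real process
# group with plain torch chunking, so every name and size below is illustrative; it only
# checks that "column-parallel linear -> row-parallel linear plus an all-reduce"
# (the role of reduce_input above) reproduces two stacked dense layers.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
world_size, batch, d_in, d_hidden = 2, 4, 8, 16
w1 = torch.randn(d_hidden, d_in)   # full weight of the column-parallel layer
w2 = torch.randn(d_in, d_hidden)   # full weight of the row-parallel layer
x = torch.randn(batch, d_in)

# Linear1D_Col: each rank keeps d_hidden / world_size output rows of w1.
h_shards = [F.linear(x, w) for w in w1.chunk(world_size, dim=0)]
# Linear1D_Row: each rank keeps the matching input columns of w2; summing the partial
# results stands in for the all-reduce performed across the 1D group.
partials = [F.linear(h, w) for h, w in zip(h_shards, w2.chunk(world_size, dim=1))]
y_parallel = sum(partials)

y_reference = F.linear(F.linear(x, w1), w2)
assert torch.allclose(y_parallel, y_reference, atol=1e-5)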

11
colossalai/nn/layer/parallel_2d/__init__.py

@@ -0,0 +1,11 @@
from ._operation import Matmul_AB_2D, Matmul_ABT_2D, Matmul_ATB_2D, Add_Bias_2D, matmul_2d
from ._transformer import TransformerMLP2D, TransformerSelfAttention2D, TransformerLayer2D
from ._vit import ViTMLP2D, ViTSelfAttention2D, ViTHead2D, ViTPatchEmbedding2D, ViTTokenFuser2D, ViTInputSplitter2D
from .layers import Linear2D, LayerNorm2D
__all__ = [
'Matmul_AB_2D', 'Matmul_ABT_2D', 'Matmul_ATB_2D', 'Add_Bias_2D', 'matmul_2d',
'TransformerMLP2D', 'TransformerSelfAttention2D', 'TransformerLayer2D',
'ViTMLP2D', 'ViTSelfAttention2D', 'ViTHead2D', 'ViTPatchEmbedding2D', 'ViTTokenFuser2D', 'ViTInputSplitter2D',
'Linear2D', 'LayerNorm2D'
]

522
colossalai/nn/layer/parallel_2d/_operation.py

@@ -0,0 +1,522 @@
from typing import Any, Tuple
import torch
import torch.distributed as dist
from torch import Tensor
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device
def matmul_2d(a,
b,
summa_dim,
out_shape,
row_rank=None,
col_rank=None,
row_parallel_mode=ParallelMode.PARALLEL_2D_ROW,
col_parallel_mode=ParallelMode.PARALLEL_2D_COL,
):
"""Matrix multiplication for 2D parallelism
:param a: matrix :math:`A`
:type a: torch.tensor
:param b: matrix :math:`B`
:type b: torch.tensor
:param summa_dim: dimension of the SUMMA grid for 2D parallelism
:type summa_dim: int
:param out_shape: shape of output tensor
:type out_shape: tuple
:param row_rank: the rank of row, defaults to None
:type row_rank: int, optional
:param col_rank: the rank of column, defaults to None
:type col_rank: int, optional
:param row_parallel_mode: row parallel mode, defaults to ParallelMode.PARALLEL_2D_ROW
:type row_parallel_mode: ParallelMode, optional
:param col_parallel_mode: column parallel mode, defaults to ParallelMode.PARALLEL_2D_COL
:type col_parallel_mode: ParallelMode, optional
:return: :math:`C = AB`
:rtype: torch.tensor
"""
if row_rank is None:
row_rank = gpc.get_local_rank(col_parallel_mode)
if col_rank is None:
col_rank = gpc.get_local_rank(row_parallel_mode)
data_parallel_rank = 0 if not gpc.is_initialized(
ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA)
pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(
ParallelMode.PIPELINE)
pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(
ParallelMode.PIPELINE)
tensor_parallel_size = summa_dim ** 2
return Matmul_AB_2D(a, b, summa_dim, out_shape, row_rank, col_rank, row_parallel_mode, col_parallel_mode,
data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size
)
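# Illustrative single-process rendering of the SUMMA step that Matmul_AB_2D performs
# collectively (assumed sizes, no process groups): rank (i, j) owns tiles A_ij and B_ij,
# and in round k the row group broadcasts A_ik while the column group broadcasts B_kj,
# so each rank accumulates C_ij = sum_k A_ik @ B_kj. Broadcasts are replaced here by
# direct tile indexing.
import torch

torch.manual_seed(0)
q, m, k, n = 2, 8, 6, 10   # summa_dim and full matrix sizes (all divisible by q)
A, B = torch.randn(m, k), torch.randn(k, n)
A_tiles = [list(r.chunk(q, dim=1)) for r in A.chunk(q, dim=0)]
B_tiles = [list(r.chunk(q, dim=1)) for r in B.chunk(q, dim=0)]

C_tiles = [[torch.zeros(m // q, n // q) for _ in range(q)] for _ in range(q)]
for i in range(q):
    for j in range(q):
        for step in range(q):   # one broadcast round per step in the real kernel
            C_tiles[i][j] += A_tiles[i][step] @ B_tiles[step][j]

C = torch.cat([torch.cat(row, dim=1) for row in C_tiles], dim=0)
assert torch.allclose(C, A @ B, atol=1e-5)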
class Matmul_AB_2D(torch.autograd.Function):
"""Matrix multiplication for :math:`C = AB`
"""
@staticmethod
def forward(ctx: Any,
A: Tensor,
B: Tensor,
summa_dim: int,
out_shape: Tuple[int, ...],
row_rank: int,
col_rank: int,
row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode,
data_parallel_rank: int,
pipeline_parallel_rank: int,
pipeline_parallel_size: int,
tensor_parallel_size: int) -> Tensor:
# A: [b / q, s, h / q] -> [(b * s) / q, h / q]
# B: [h / q, s / q]
# C: [b / q, s, s / q] -> [(b * s) / q, s / q]
assert A.shape[-1] == B.shape[-2], \
'Invalid shapes: A={}, B={} for AB.'.format(A.shape, B.shape)
if ctx:
ctx.save_for_backward(A, B)
A_shape = A.shape
A = A.reshape((-1, A_shape[-1]))
B_shape = B.shape
B = B.reshape((-1, B_shape[-1]))
C_shape = (A.shape[0], B.shape[-1])
C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device())
for i in range(summa_dim):
A_temp = A.clone()
B_temp = B.clone()
src_a = i + summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
dist.broadcast(A_temp, src=src_a,
group=gpc.get_group(row_parallel_mode))
src_b = col_rank + summa_dim * i + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
dist.broadcast(B_temp, src=src_b,
group=gpc.get_group(col_parallel_mode))
torch.addmm(C, A_temp, B_temp, out=C)
out = C.reshape(out_shape)
if ctx:
ctx.summa_dim = summa_dim
ctx.row_rank = row_rank
ctx.col_rank = col_rank
ctx.row_parallel_mode = row_parallel_mode
ctx.col_parallel_mode = col_parallel_mode
ctx.A_shape = A_shape
ctx.B_shape = B_shape
ctx.data_parallel_rank = data_parallel_rank
ctx.pipeline_parallel_rank = pipeline_parallel_rank
ctx.pipeline_parallel_size = pipeline_parallel_size
ctx.tensor_parallel_size = tensor_parallel_size
return out
@staticmethod
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors
A_grad = Matmul_ABT_2D.forward(
None,
output_grad, B,
ctx.summa_dim, ctx.A_shape,
ctx.row_rank, ctx.col_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
B_grad = Matmul_ATB_2D.forward(
None,
A, output_grad,
ctx.summa_dim, ctx.B_shape,
ctx.row_rank, ctx.col_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None
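# The broadcast sources above are global ranks. The sketch below (pure arithmetic,
# assuming no data or pipeline parallelism so the trailing offset terms vanish) lists
# which rank serves each round on a 2 x 2 SUMMA grid.
summa_dim = 2
offset = 0   # data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + pipeline_parallel_rank * tensor_parallel_size
for row_rank in range(summa_dim):
    for col_rank in range(summa_dim):
        for i in range(summa_dim):
            src_a = i + summa_dim * row_rank + offset   # owner of the A tile broadcast in the row group
            src_b = col_rank + summa_dim * i + offset   # owner of the B tile broadcast in the column group
            print(f'rank ({row_rank},{col_rank}) round {i}: A from {src_a}, B from {src_b}')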
class Matmul_ABT_2D(torch.autograd.Function):
"""Matrix multiplication for :math:`C = AB^T`
"""
@staticmethod
def forward(ctx: Any,
A: Tensor,
B: Tensor,
summa_dim: int,
out_shape: Tuple[int, ...],
row_rank: int,
col_rank: int,
row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode,
data_parallel_rank: int,
pipeline_parallel_rank: int,
pipeline_parallel_size: int,
tensor_parallel_size: int
) -> Tensor:
assert A.shape[-1] == B.shape[-1], \
'Invalid shapes: A={}, B={} for ABT.'.format(A.shape, B.shape)
if ctx:
ctx.save_for_backward(A, B)
A_shape = A.shape
A = A.reshape((-1, A_shape[-1]))
B_shape = B.shape
B = B.reshape((-1, B_shape[-1]))
C_shape = (A.shape[0], B.shape[0])
C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
for i in range(summa_dim):
B_temp = B.clone()
# C_temp = torch.zeros(C_shape, dtype=C.dtype, device=get_current_device())
src_b = col_rank + summa_dim * i + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
dist.broadcast(B_temp, src=src_b,
group=gpc.get_group(col_parallel_mode))
C_temp = torch.matmul(A, B_temp.transpose(0, 1))
src_c = i + summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
dist.reduce(C_temp, dst=src_c,
group=gpc.get_group(row_parallel_mode))
if i == col_rank:
C = C_temp.clone()
out = C.reshape(out_shape)
if ctx:
ctx.summa_dim = summa_dim
ctx.row_rank = row_rank
ctx.col_rank = col_rank
ctx.row_parallel_mode = row_parallel_mode
ctx.col_parallel_mode = col_parallel_mode
ctx.A_shape = A_shape
ctx.B_shape = B_shape
ctx.data_parallel_rank = data_parallel_rank
ctx.pipeline_parallel_rank = pipeline_parallel_rank
ctx.pipeline_parallel_size = pipeline_parallel_size
ctx.tensor_parallel_size = tensor_parallel_size
return out
@staticmethod
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors
A_grad = Matmul_AB_2D.forward(
None,
output_grad, B,
ctx.summa_dim, ctx.A_shape,
ctx.row_rank, ctx.col_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
B_grad = Matmul_ATB_2D.forward(
None,
output_grad, A,
ctx.summa_dim, ctx.B_shape,
ctx.row_rank, ctx.col_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None
class Matmul_ATB_2D(torch.autograd.Function):
"""Matrix multiplication for :math:`C = A^TB`
"""
@staticmethod
def forward(ctx: Any,
A: Tensor,
B: Tensor,
summa_dim: int,
out_shape: Tuple[int, ...],
row_rank: int,
col_rank: int,
row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode,
data_parallel_rank: int,
pipeline_parallel_rank: int,
pipeline_parallel_size: int,
tensor_parallel_size: int
) -> Tensor:
assert A.shape[-2] == B.shape[-2], \
'Invalid shapes: A={}, B={} for ATB.'.format(A.shape, B.shape)
if ctx:
ctx.save_for_backward(A, B)
A_shape = A.shape
A = A.reshape((-1, A_shape[-1]))
B_shape = B.shape
B = B.reshape((-1, B_shape[-1]))
C_shape = (A.shape[-1], B.shape[-1])
C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
for i in range(summa_dim):
A_temp = A.clone()
# C_temp = torch.zeros(C_shape, dtype=C.dtype, device=get_current_device())
src_a = i + summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
dist.broadcast(A_temp, src=src_a,
group=gpc.get_group(row_parallel_mode))
C_temp = torch.matmul(A_temp.transpose(0, 1), B)
src_c = col_rank + summa_dim * i + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
dist.reduce(C_temp, dst=src_c,
group=gpc.get_group(col_parallel_mode))
if i == row_rank:
C = C_temp.clone()
out = C.reshape(out_shape)
if ctx:
ctx.summa_dim = summa_dim
ctx.row_rank = row_rank
ctx.col_rank = col_rank
ctx.row_parallel_mode = row_parallel_mode
ctx.col_parallel_mode = col_parallel_mode
ctx.A_shape = A_shape
ctx.B_shape = B_shape
ctx.data_parallel_rank = data_parallel_rank
ctx.pipeline_parallel_rank = pipeline_parallel_rank
ctx.pipeline_parallel_size = pipeline_parallel_size
ctx.tensor_parallel_size = tensor_parallel_size
return out
@staticmethod
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors
A_grad = Matmul_ABT_2D.forward(
None,
B, output_grad,
ctx.summa_dim, ctx.A_shape,
ctx.row_rank, ctx.col_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
B_grad = Matmul_AB_2D.forward(
None,
A, output_grad,
ctx.summa_dim, ctx.B_shape,
ctx.row_rank, ctx.col_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None
class Add_Bias_2D(torch.autograd.Function):
"""Matrix add bias: :math:`C = A + b`
"""
@staticmethod
def forward(ctx: Any,
input: Tensor,
bias: Tensor,
output_size_per_partition: int,
row_rank: int,
col_rank: int,
row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode,
skip_bias_add: bool,
data_parallel_rank: int,
pipeline_parallel_rank: int,
pipeline_parallel_size: int,
tensor_parallel_size: int
) -> Tensor:
if row_rank == 0:
bias_temp = bias.clone()
else:
bias_temp = torch.zeros(
output_size_per_partition,
dtype=bias.dtype,
device=get_current_device())
src_rank = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
dist.broadcast(bias_temp, src=src_rank,
group=gpc.get_group(col_parallel_mode))
ctx.row_rank = row_rank
ctx.col_rank = col_rank
ctx.row_parallel_mode = row_parallel_mode
ctx.col_parallel_mode = col_parallel_mode
ctx.bias = skip_bias_add
ctx.data_parallel_rank = data_parallel_rank
ctx.pipeline_parallel_rank = pipeline_parallel_rank
ctx.pipeline_parallel_size = pipeline_parallel_size
ctx.tensor_parallel_size = tensor_parallel_size
if skip_bias_add:
return bias_temp
else:
output = input + bias_temp
return output
@staticmethod
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
row_rank = ctx.row_rank
col_rank = ctx.col_rank
row_parallel_mode = ctx.row_parallel_mode
col_parallel_mode = ctx.col_parallel_mode
data_parallel_rank = ctx.data_parallel_rank
pipeline_parallel_rank = ctx.pipeline_parallel_rank
pipeline_parallel_size = ctx.pipeline_parallel_size
tensor_parallel_size = ctx.tensor_parallel_size
if ctx.bias:
dst_rank = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
dist.reduce(output_grad, dst=dst_rank,
group=gpc.get_group(col_parallel_mode))
if row_rank == 0:
return None, output_grad, None, None, None, None, None, None, None, None, None, None
else:
# for compatibility with the ZeRO optimizer, gradients must not be None
grad_tmp = torch.zeros_like(output_grad)
return None, grad_tmp, None, None, None, None, None, None, None, None, None, None
else:
reduce_dim = tuple(range(output_grad.ndim - 1))
reduce = torch.sum(output_grad, dim=reduce_dim)
dst_rank = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
dist.reduce(reduce, dst=dst_rank,
group=gpc.get_group(col_parallel_mode))
if row_rank == 0:
return output_grad, reduce, None, None, None, None, None, None, None, None, None, None
else:
# for compatibility with the ZeRO optimizer, gradients must not be None
reduce_tmp = torch.zeros_like(reduce)
return output_grad, reduce_tmp, None, None, None, None, None, None, None, None, None, None
class _LayerNorm_2D(torch.autograd.Function):
@staticmethod
def forward(ctx: Any,
input: Tensor,
E_x: Tensor,
Var_x: Tensor,
hidden_size: int,
row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode) -> Tensor:
input = input - E_x
# in here, input = x - E[x], Var_x = 1 / sqrt(Var[x] + eps)
ctx.normalized_shape = hidden_size
output = input * Var_x
ctx.save_for_backward(output, Var_x)
ctx.row_parallel_mode = row_parallel_mode
ctx.col_parallel_mode = col_parallel_mode
return output
@staticmethod
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
row_parallel_mode = ctx.row_parallel_mode
col_parallel_mode = ctx.col_parallel_mode
x, Var_x = ctx.saved_tensors
# in here, Var_x = 1 / sqrt(Var[x] + eps), x = (x - E[x]) * Var_x
output_grad_sum = torch.sum(output_grad, dim=-1, keepdim=True)
torch.distributed.all_reduce(
output_grad_sum, group=gpc.get_group(row_parallel_mode))
output_grad_sum /= ctx.normalized_shape
output_grad_mul_x_sum = torch.sum(
output_grad * x, dim=-1, keepdim=True)
torch.distributed.all_reduce(
output_grad_mul_x_sum, group=gpc.get_group(row_parallel_mode))
output_grad_mul_x_sum /= ctx.normalized_shape
input_grad = output_grad.clone()
input_grad -= x * output_grad_mul_x_sum
input_grad -= output_grad_sum
input_grad *= Var_x
return input_grad, None, None, None, None, None
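# Single-device sanity check (no parallel groups involved) that the closed-form input
# gradient used in _LayerNorm_2D.backward matches autograd for an affine-free layer norm
# over the last dimension; the shapes and eps value are illustrative.
import torch

torch.manual_seed(0)
x = torch.randn(4, 16, requires_grad=True)
eps = 1e-5
mean = x.mean(dim=-1, keepdim=True)
var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)
inv_std = 1.0 / torch.sqrt(var + eps)   # plays the role of Var_x above
x_hat = (x - mean) * inv_std            # plays the role of the saved output

dy = torch.randn_like(x)
x_hat.backward(dy)

h = x.shape[-1]
manual = (dy
          - dy.sum(dim=-1, keepdim=True) / h
          - x_hat.detach() * (dy * x_hat.detach()).sum(dim=-1, keepdim=True) / h) * inv_std.detach()
assert torch.allclose(x.grad, manual, atol=1e-5)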
# class Sum_2D(torch.autograd.Function):
#
# @staticmethod
# def forward(ctx: Any,
# inputs: Tensor,
# dim: int,
# summa_dim: int,
# row_parallel_mode: ParallelMode,
# keepdim: bool = False) -> Tensor:
# # input: [b/q, s, h/q]
# empty_cache()
# ctx.save_for_backward(inputs)
# # sum: [b/q, s]
# out = torch.sum(inputs, dim=dim, keepdim=keepdim)
# torch.distributed.all_reduce(out, group=gpc.get_group(row_parallel_mode))
# return out
#
# @staticmethod
# def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
# with torch.no_grad():
# inputs = ctx.saved_tensors
# input_grad = torch.ones(inputs.shape, dtype=output_grad.dtype)
# return input_grad, None, None, None, None, None
class _ViT_Split_Input_2D(torch.autograd.Function):
@staticmethod
def forward(ctx: Any,
inputs: Tensor,
batch_size: int,
summa_dim: int,
col_parallel_mode: ParallelMode) -> Tensor:
# inputs: [b, s, h/q]
# output: [b/q, s, h/q]
ctx.BATCH_SIZE = batch_size
ctx.summa_dim = summa_dim
ctx.col_parallel_mode = col_parallel_mode
row_rank = gpc.get_local_rank(col_parallel_mode)
output = torch.chunk(inputs, summa_dim, dim=0)[row_rank]
output = output.clone()
return output
@staticmethod
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
# output_grad: [b/q, s, h/q]
# grads: [b, s, h/q]
grads_shape = (ctx.BATCH_SIZE,) + output_grad.shape[1:]
grads = torch.empty(grads_shape,
dtype=output_grad.dtype,
device=get_current_device())
dist.all_gather(list(grads.chunk(ctx.summa_dim, dim=0)),
output_grad.contiguous(),
group=gpc.get_group(ctx.col_parallel_mode))
return grads, None, None, None

220
colossalai/nn/layer/parallel_2d/_transformer.py

@@ -0,0 +1,220 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math
import torch
from torch import nn as nn, Tensor
from colossalai.nn.layer._common_utils import divide, ACT2FN
from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization, get_summa_dim_from_env
from colossalai.registry import LAYERS
from .layers import Linear2D, LayerNorm2D
from ..base_layer import ParallelLayer
@LAYERS.register_module
class TransformerMLP2D(ParallelLayer):
"""
The MLP takes an input with hidden size h, projects it to a hidden
dimension of mlp_ratio * h, applies a nonlinear transformation, and
projects the state back to hidden size h. Dropout is applied at
the end.
:param in_features: the size of input tensor
:type in_features: int
:param mlp_ratio: hidden size of MLP divided by embedding dim, defaults to 4.0
:type mlp_ratio: float, optional
:param act_func: activation function, defaults to 'gelu'
:type act_func: str, optional
:param dropout_prob: dropout probability, defaults to 0.
:type dropout_prob: float, optional
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param skip_bias_add: If set to ``True``, the linear layers skip the bias addition so that it can be fused into a later kernel, defaults to False
:type skip_bias_add: bool, optional
"""
def __init__(self,
in_features: int,
mlp_ratio: float = 4.0,
act_func: str = 'gelu',
dropout_prob: float = 0.,
dtype=None,
skip_bias_add: bool = False
):
super().__init__()
assert_summa_initialization()
self.summa_dim = get_summa_dim_from_env()
self.in_features = in_features
self.skip_bias_add = skip_bias_add
# Project to h * mlp_ratio.
self.dense_1 = Linear2D(
in_features,
int(mlp_ratio * in_features),
dtype=dtype,
skip_bias_add=self.skip_bias_add
)
assert act_func in ACT2FN.keys(), f'Invalid value for argument act_func, ' \
f'activation function can only be {list(ACT2FN.keys())}'
self.activation_func = ACT2FN[act_func]
# Project back to h.
self.dense_2 = Linear2D(
int(mlp_ratio * in_features),
in_features,
dtype=dtype,
skip_bias_add=self.skip_bias_add
)
self.dropout = nn.Dropout(dropout_prob)
self.layernorm = LayerNorm2D(in_features, dtype=dtype)
def forward(self, x: Tensor) -> Tensor:
if self.skip_bias_add:
intermediate_output, _ = self.dense_1(x)
else:
intermediate_output = self.dense_1(x)
intermediate_output = self.activation_func(intermediate_output)
if self.skip_bias_add:
output, _ = self.dense_2(intermediate_output)
else:
output = self.dense_2(intermediate_output)
output = self.dropout(output)
output = self.layernorm(x + output)
return output
@LAYERS.register_module
class TransformerSelfAttention2D(ParallelLayer):
"""Self attention layer for 2D parallel Transformer
:param hidden_size: hidden size
:type hidden_size: int
:param num_attention_heads: number of attention heads
:type num_attention_heads: int
:param attention_dropout_prob: dropout probability for attention layer
:type attention_dropout_prob: float
:param hidden_dropout_prob: dropout probability for hidden layer
:type hidden_dropout_prob: float
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self,
hidden_size: int,
num_attention_heads: int,
attention_dropout_prob: float,
hidden_dropout_prob: float,
dtype=None,
):
super().__init__()
assert_summa_initialization()
self.summa_dim = get_summa_dim_from_env()
self.hidden_size = hidden_size
self.num_attention_heads = divide(num_attention_heads, self.summa_dim)
self.attention_head_size = divide(hidden_size, num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query_key_value = Linear2D(
hidden_size,
3 * hidden_size,
dtype=dtype,
)
self.attention_dropout = nn.Dropout(attention_dropout_prob)
self.dense = Linear2D(
hidden_size,
hidden_size,
dtype=dtype,
)
self.dropout = nn.Dropout(hidden_dropout_prob)
self.layernorm = LayerNorm2D(
hidden_size,
dtype=dtype)
def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
query_key_value = self.query_key_value(hidden_states)
new_qkv_shape = query_key_value.shape[:-1] + \
(self.num_attention_heads, 3 * self.attention_head_size)
query_key_value = query_key_value.view(new_qkv_shape)
query_key_value = query_key_value.permute((0, 2, 1, 3))
query_layer, key_layer, value_layer = torch.chunk(
query_key_value, 3, dim=-1)
attention_scores = torch.matmul(
query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / \
math.sqrt(self.attention_head_size)
attention_scores = attention_scores + attention_mask
attention_probs = nn.Softmax(dim=-1)(attention_scores)
attention_probs = self.attention_dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute((0, 2, 1, 3)).contiguous()
new_context_layer_shape = context_layer.size()[
:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
output = self.dense(context_layer)
output = self.dropout(output)
attention_output = self.layernorm(hidden_states + output)
return attention_output
@LAYERS.register_module
class TransformerLayer2D(ParallelLayer):
"""Transformer layer which contains a self-attention layer and a MLP layer
:param hidden_size: hidden size
:type hidden_size: int
:param num_attention_heads: number of attention heads
:type num_attention_heads: int
:param act_func: activation function, defaults to 'gelu'
:type act_func: str, optional
:param mlp_ratio: hidden size of MLP divided by embedding dim, defaults to 4.0
:type mlp_ratio: float, optional
:param attention_dropout_prob: dropout probability for attention layer, defaults to 0.
:type attention_dropout_prob: float, optional
:param hidden_dropout_prob: dropout probability for hidden layers, defaults to 0.
:type hidden_dropout_prob: float, optional
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self,
hidden_size: int,
num_attention_heads: int,
act_func: str = 'gelu',
mlp_ratio: float = 4.0,
attention_dropout_prob: float = 0.,
hidden_dropout_prob: float = 0.,
dtype=None,
):
super().__init__()
self.attention = TransformerSelfAttention2D(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
attention_dropout_prob=attention_dropout_prob,
hidden_dropout_prob=hidden_dropout_prob,
dtype=dtype,
)
self.mlp = TransformerMLP2D(
in_features=hidden_size,
dropout_prob=hidden_dropout_prob,
act_func=act_func,
mlp_ratio=mlp_ratio,
dtype=dtype,
)
def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
attention_output = self.attention(hidden_states, attention_mask)
output = self.mlp(attention_output)
return output
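# Head-partition arithmetic behind TransformerSelfAttention2D, shown with assumed example
# numbers: each rank keeps num_heads / summa_dim heads while the head size is derived from
# the full hidden size, so the local all_head_size equals hidden_size / summa_dim -- the
# same width as the Linear2D partitions it reads from and writes to.
hidden_size, num_heads, summa_dim = 768, 12, 2
local_heads = num_heads // summa_dim    # divide(num_attention_heads, summa_dim)
head_size = hidden_size // num_heads    # divide(hidden_size, num_attention_heads)
all_head_size = local_heads * head_size
assert all_head_size == hidden_size // summa_dim   # 384 per rank on a 2 x 2 grid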

23
colossalai/nn/layer/parallel_2d/_utils.py

@@ -0,0 +1,23 @@
import os
from colossalai.context.parallel_mode import ParallelMode
from colossalai.context.process_group_initializer.initializer_2d import SUMMA_DIM
from colossalai.core import global_context as gpc
def get_summa_dim_from_env() -> int:
try:
summa_dim = os.environ[SUMMA_DIM]
summa_dim = int(summa_dim)
assert summa_dim > 0, 'SUMMA_DIM must be larger than zero'
return summa_dim
except KeyError as e:
raise EnvironmentError('SUMMA_DIM is not found in the current environment, '
'please make sure that you have used the correct process group initializer')
def assert_summa_initialization():
assert gpc.is_initialized(ParallelMode.PARALLEL_2D_COL) and \
gpc.is_initialized(ParallelMode.PARALLEL_2D_ROW), \
'Both TWO_DIMENSION_COL and TWO_DIMENSION_ROW must be initialized by the process group initializer'
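# The SUMMA grid side is tied to the tensor-parallel world size: a 2D mesh needs a square
# number of tensor-parallel ranks and summa_dim is its square root (tensor_parallel_size
# equals summa_dim ** 2 elsewhere in this package). A tiny illustrative check:
import math
tensor_parallel_size = 4
summa_dim = int(math.sqrt(tensor_parallel_size))
assert summa_dim * summa_dim == tensor_parallel_size, \
    '2D tensor parallelism expects a square number of ranks'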

391
colossalai/nn/layer/parallel_2d/_vit.py

@@ -0,0 +1,391 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math
import torch
from torch import nn as nn, Tensor, distributed as dist
from colossalai.context import seed, ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer._common_utils import divide, ACT2FN
from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization, get_summa_dim_from_env
from colossalai.nn.layer.vanilla_vision_transformer.layers import to_2tuple
from colossalai.registry import LAYERS
from colossalai.utils import checkpoint
from colossalai.utils import get_current_device
from ._operation import _ViT_Split_Input_2D
from .layers import Linear2D
from .._common_utils import set_tensor_parallel_attribute
from ..base_layer import ParallelLayer
@LAYERS.register_module
class ViTMLP2D(ParallelLayer):
"""MLP layer for 2D parallel Vision Transformer
:param in_features: size of each input sample
:type in_features: int
:param mlp_ratio: hidden size of MLP divided by embedding dim
:type mlp_ratio: int
:param act_func: activation function, defaults to 'gelu'
:type act_func: str, optional
:param dropout_prob: dropout probability, defaults to 0.
:type dropout_prob: float, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param checkpoint: whether to checkpoint the layer, defaults to False
:type checkpoint: bool, optional
"""
def __init__(self,
in_features: int,
mlp_ratio: int,
act_func: str = 'gelu',
dropout_prob: float = 0.,
dtype=None,
checkpoint: bool = False
):
super().__init__()
assert_summa_initialization()
self.summa_dim = get_summa_dim_from_env()
self.in_features = in_features
self.mlp_ratio = mlp_ratio
self.checkpoint = checkpoint
# Project to mlp_ratio * h.
self.dense_1 = Linear2D(
self.in_features,
self.mlp_ratio * self.in_features,
dtype=dtype,
)
self.act = ACT2FN[act_func]
# Project back to h.
self.dense_2 = Linear2D(
self.mlp_ratio * self.in_features,
self.in_features,
dtype=dtype,
)
self.dropout = nn.Dropout(dropout_prob)
def _forward(self, hidden_states: Tensor) -> Tensor:
intermediate_output = self.dense_1(hidden_states)
intermediate_output = self.act(intermediate_output)
with seed(ParallelMode.TENSOR):
intermediate_output = self.dropout(intermediate_output)
output = self.dense_2(intermediate_output)
with seed(ParallelMode.TENSOR):
output = self.dropout(output)
return output
def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
return checkpoint(self._forward, hidden_states)
def forward(self, hidden_states: Tensor) -> Tensor:
if self.checkpoint:
return self._checkpoint_forward(hidden_states)
else:
return self._forward(hidden_states)
@LAYERS.register_module
class ViTSelfAttention2D(ParallelLayer):
"""Self-attention layer for 2D parallel Vision Transformer
:param hidden_size: hidden size
:type hidden_size: int
:param num_attention_heads: number of attention heads
:type num_attention_heads: int
:param attention_dropout_prob: dropout probability for attention layers
:type attention_dropout_prob: float
:param hidden_dropout_prob: dropout probability for hidden layers
:type hidden_dropout_prob: float
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param checkpoint: whether to checkpoint the layer, defaults to False
:type checkpoint: bool, optional
"""
def __init__(self,
hidden_size: int,
num_attention_heads: int,
attention_dropout_prob: float,
hidden_dropout_prob: float,
dtype=None,
checkpoint: bool = False
):
super().__init__()
assert_summa_initialization()
self.summa_dim = get_summa_dim_from_env()
self.hidden_size = hidden_size
self.num_attention_heads = divide(num_attention_heads, self.summa_dim)
self.attention_head_size = divide(hidden_size, num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.checkpoint = checkpoint
self.query_key_value = Linear2D(
hidden_size,
3 * hidden_size,
dtype=dtype,
)
self.attention_dropout = nn.Dropout(attention_dropout_prob)
self.dense = Linear2D(
hidden_size,
hidden_size,
dtype=dtype,
)
self.dropout = nn.Dropout(hidden_dropout_prob)
self.softmax = nn.Softmax(dim=-1)
def _forward(self, hidden_states: Tensor) -> Tensor:
query_key_value = self.query_key_value(hidden_states)
new_qkv_shape = query_key_value.shape[:-1] + \
(self.num_attention_heads, 3 * self.attention_head_size)
query_key_value = query_key_value.view(new_qkv_shape)
query_key_value = query_key_value.permute((0, 2, 1, 3))
query_layer, key_layer, value_layer = torch.chunk(
query_key_value, 3, dim=-1)
attention_scores = torch.matmul(
query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / \
math.sqrt(self.attention_head_size)
attention_probs = self.softmax(attention_scores)
with seed(ParallelMode.TENSOR):
attention_probs = self.attention_dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.transpose(1, 2)
new_context_layer_shape = context_layer.size()[
:-2] + (self.all_head_size,)
context_layer = context_layer.reshape(new_context_layer_shape)
output = self.dense(context_layer)
with seed(ParallelMode.TENSOR):
output = self.dropout(output)
return output
def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
return checkpoint(self._forward, hidden_states)
def forward(self, hidden_states: Tensor) -> Tensor:
if self.checkpoint:
return self._checkpoint_forward(hidden_states)
else:
return self._forward(hidden_states)
@LAYERS.register_module
class ViTHead2D(ParallelLayer):
"""Output layer for 2D parallel Vision Transformer
:param hidden_size: hidden size
:type hidden_size: int
:param num_classes: number of classes
:type num_classes: int
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self,
hidden_size,
num_classes,
dtype=None,
):
super().__init__()
assert_summa_initialization()
self.summa_dim = get_summa_dim_from_env()
self.linear = Linear2D(
hidden_size,
num_classes,
dtype=dtype,
)
def forward(self, x: Tensor) -> Tensor:
x = x[:, 0]
x = self.linear(x)
return x
@LAYERS.register_module
class ViTPatchEmbedding2D(ParallelLayer):
""" 2D Image to Patch Embedding
:param img_size: image size
:type img_size: int
:param patch_size: patch size
:type patch_size: int
:param embed_dim: dimension of embedding
:type embed_dim: int
:param in_chans: number of channels of input image, defaults to 3
:type in_chans: int, optional
:param flatten: whether to flatten output tensor, defaults to True
:type flatten: bool, optional
"""
def __init__(self,
img_size,
patch_size,
embed_dim,
in_chans=3,
flatten=True):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
assert_summa_initialization()
self.summa_dim = get_summa_dim_from_env()
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0],
img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
self.embed_dim = embed_dim // self.summa_dim
with seed(ParallelMode.TENSOR):
# ensure the partitions are initialized differently
self.proj = nn.Conv2d(in_chans,
self.embed_dim,
kernel_size=patch_size,
stride=patch_size
)
# sync
self._broadcast_conv_params()
self.proj.weight.register_hook(self._sync_grad_during_backward)
self.proj.bias.register_hook(self._sync_grad_during_backward)
def _set_tensor_parallel_attribute(self):
set_tensor_parallel_attribute(self.proj.weight)
set_tensor_parallel_attribute(self.proj.bias)
def _broadcast_conv_params(self) -> None:
self.to(get_current_device())
ranks_in_col = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2D_COL)
dist.broadcast(self.proj.weight, src=ranks_in_col[0],
group=gpc.get_group(ParallelMode.PARALLEL_2D_COL))
dist.broadcast(self.proj.bias, src=ranks_in_col[0],
group=gpc.get_group(ParallelMode.PARALLEL_2D_COL))
def _sync_grad_during_backward(self, grad: Tensor) -> None:
dist.all_reduce(grad, group=gpc.get_group(
ParallelMode.PARALLEL_2D_COL))
grad = grad / self.summa_dim
return grad
def forward(self, x: Tensor) -> Tensor:
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
return x
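# Patch-grid arithmetic used above, with assumed example values: a 224 x 224 image split
# into 16 x 16 patches gives a 14 x 14 grid (196 patches), and on a 2 x 2 SUMMA grid each
# rank's convolution produces embed_dim / summa_dim output channels.
img_size, patch_size, embed_dim, summa_dim = 224, 16, 768, 2
grid_size = (img_size // patch_size, img_size // patch_size)   # (14, 14)
num_patches = grid_size[0] * grid_size[1]                      # 196
local_embed_dim = embed_dim // summa_dim                       # 384 channels per rank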
@LAYERS.register_module
class ViTTokenFuser2D(ParallelLayer):
"""
Fuse the cls token and position embedding into the input
:param img_size: image size
:type img_size: int
:param patch_size: patch size
:type patch_size: int
:param embed_dim: dimension of embedding
:type embed_dim: int
:param drop_rate: dropout probability, defaults to 0.
:type drop_rate: float, optional
"""
def __init__(self,
img_size,
patch_size,
embed_dim,
drop_rate=0.
):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
assert_summa_initialization()
self.summa_dim = get_summa_dim_from_env()
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0],
img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.embed_dim = embed_dim
self.cls_token = nn.Parameter(torch.zeros(
1, 1, self.embed_dim // self.summa_dim))
self.pos_embed = nn.Parameter(torch.zeros(
1, self.num_patches + 1, self.embed_dim // self.summa_dim))
# move to cuda before broadcast
self.to(get_current_device())
# sync param in both forward and backward
_cls_token = self.cls_token.view(-1)
_pos_embed = self.pos_embed.view(-1)
self._param = torch.cat([_cls_token, _pos_embed], dim=0)
self._broadcast_params(self._param)
self._param.register_hook(self._sync_grad_hook)
self.pos_drop = nn.Dropout(p=drop_rate)
self._set_tensor_parallel_attribute()
def _set_tensor_parallel_attribute(self):
set_tensor_parallel_attribute(self.cls_token)
set_tensor_parallel_attribute(self.pos_embed)
def _broadcast_params(self, param) -> None:
" broadcast to all column ranks for data consistency "
ranks_in_col = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2D_COL)
col_group = gpc.get_group(ParallelMode.PARALLEL_2D_COL)
dist.broadcast(param, src=ranks_in_col[0],
group=col_group)
def _sync_grad_hook(self, grad) -> None:
dist.all_reduce(grad, group=gpc.get_group(
ParallelMode.PARALLEL_2D_COL))
grad = grad / self.summa_dim
return grad
def forward(self, x: Tensor) -> Tensor:
# stole cls_tokens impl from Phil Wang, thanks
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_token, x), dim=1)
with seed(ParallelMode.TENSOR):
x = self.pos_drop(x + self.pos_embed)
return x
@LAYERS.register_module
class ViTInputSplitter2D(ParallelLayer):
"""Split the input tensor for 2D parallel Vision Transformer
"""
def __init__(self):
super().__init__()
assert_summa_initialization()
self.summa_dim = get_summa_dim_from_env()
def forward(self, x: Tensor) -> Tensor:
batch_size = x.size(0)
return _ViT_Split_Input_2D.apply(
x,
batch_size,
self.summa_dim,
ParallelMode.PARALLEL_2D_COL
)
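# Batch-splitting sketch for ViTInputSplitter2D / _ViT_Split_Input_2D (single process,
# hypothetical sizes): the input of shape [b, s, h/q] is chunked along the batch dimension
# and each column rank keeps one [b/q, s, h/q] slice; the backward pass all-gathers the
# slices back into the full batch.
import torch
b, s, h_local, summa_dim = 8, 196, 384, 2
x = torch.randn(b, s, h_local)
pieces = torch.chunk(x, summa_dim, dim=0)
assert all(p.shape == (b // summa_dim, s, h_local) for p in pieces)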

258
colossalai/nn/layer/parallel_2d/layers.py

@@ -0,0 +1,258 @@
import math
import torch
import torch.distributed as dist
from torch import Tensor
from torch.nn import Parameter, init as init
from colossalai.context import seed, ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import LAYERS
from colossalai.utils import get_current_device
from ._operation import Matmul_AB_2D, Add_Bias_2D, _LayerNorm_2D
from ._utils import get_summa_dim_from_env, assert_summa_initialization
from .._common_utils import divide, set_tensor_parallel_attribute
from ..base_layer import ParallelLayer
@LAYERS.register_module
class Linear2D(ParallelLayer):
""" Linear layer for 2D parallelism
:param in_features: size of each input sample
:type in_features: int
:param out_features: size of each output sample
:type out_features: int
:param bias: If set to ``False``, the layer will not learn an additive bias, defaults to True
:type bias: bool, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param skip_bias_add: If set to ``True``, the bias addition is skipped and the bias is returned separately so it can be fused into a later kernel, defaults to False
:type skip_bias_add: bool, optional
"""
def __init__(self,
in_features: int,
out_features: int,
bias: bool = True,
dtype=None,
skip_bias_add: bool = False
):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.skip_bias_add = skip_bias_add
# parallel settings
assert_summa_initialization()
self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
self.summa_dim = get_summa_dim_from_env()
# partitioning dimension
self.input_size_per_partition = divide(
self.in_features, self.summa_dim)
self.hidden_size_per_partition = divide(
self.out_features, self.summa_dim)
# create weight, shape: [k/q, h/q]
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
self.weight = Parameter(torch.empty(
self.input_size_per_partition,
self.hidden_size_per_partition,
**factory_kwargs))
# create bias, shape: [h/q]
if bias:
self.bias = Parameter(torch.empty(
self.hidden_size_per_partition,
**factory_kwargs))
else:
self.register_parameter('bias', None)
# initialize parameters
self.reset_parameters()
self._set_tensor_parallel_attributes()
def _set_tensor_parallel_attributes(self):
set_tensor_parallel_attribute(self.weight)
if self.bias is not None:
set_tensor_parallel_attribute(self.bias)
def reset_parameters(self) -> None:
# setting
fan_in = self.in_features
a = math.sqrt(5)
nonlinearity = 'leaky_relu'
# init weight
std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
bound = math.sqrt(3.0) * std
with seed(ParallelMode.TENSOR):
init.uniform_(self.weight, -bound, bound)
# init bias
if self.bias is not None:
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
with seed(ParallelMode.TENSOR):
init.uniform_(self.bias, -bound, bound)
def forward(self, x: Tensor) -> Tensor:
# input: [m/q, n/q, k/q]
# output: [m/q, n/q, h/q]
out_shape = x.shape[:-1] + (self.hidden_size_per_partition,)
output = Matmul_AB_2D.apply(
x,
self.weight,
self.summa_dim,
out_shape,
self.row_rank,
self.col_rank,
ParallelMode.PARALLEL_2D_ROW,
ParallelMode.PARALLEL_2D_COL,
self.data_parallel_rank,
self.pipeline_parallel_rank,
self.pipeline_parallel_size,
self.tensor_parallel_size)
if self.bias is not None:
if self.skip_bias_add:
bias = Add_Bias_2D.apply(
None,
self.bias,
self.hidden_size_per_partition,
self.row_rank,
self.col_rank,
ParallelMode.PARALLEL_2D_ROW,
ParallelMode.PARALLEL_2D_COL,
True,
self.data_parallel_rank,
self.pipeline_parallel_rank,
self.pipeline_parallel_size,
self.tensor_parallel_size
)
return output, bias
else:
output = Add_Bias_2D.apply(
output,
self.bias,
self.hidden_size_per_partition,
self.row_rank,
self.col_rank,
ParallelMode.PARALLEL_2D_ROW,
ParallelMode.PARALLEL_2D_COL,
False,
self.data_parallel_rank,
self.pipeline_parallel_rank,
self.pipeline_parallel_size,
self.tensor_parallel_size
)
return output
else:
return output
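# Weight-partition sketch for Linear2D with assumed sizes: on a q x q grid each rank stores
# an [in_features / q, out_features / q] tile, so a 1024 -> 4096 projection on a 2 x 2 grid
# keeps a 512 x 2048 tile per rank, i.e. one quarter of the dense parameter count.
in_features, out_features, summa_dim = 1024, 4096, 2
tile_shape = (in_features // summa_dim, out_features // summa_dim)
assert tile_shape == (512, 2048)
assert tile_shape[0] * tile_shape[1] * summa_dim ** 2 == in_features * out_features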
@LAYERS.register_module
class LayerNorm2D(ParallelLayer):
r"""Layer Normalization for 2D parallelism
:param normalized_shape: input shape from an expected input of size
:math:`[* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1] \times \ldots \times \text{normalized\_shape}[-1]]`.
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
:type normalized_shape: int
:param eps: a value added to the denominator for numerical stability, defaults to 1e-05
:type eps: float, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self,
normalized_shape: int,
eps: float = 1e-05,
dtype=None
):
super().__init__()
# layer norm config
self.normalized_shape = normalized_shape
self.variance_epsilon = eps
# parallel setting
assert_summa_initialization()
self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
self.summa_dim = get_summa_dim_from_env()
# partitioning dimension
self.partitioned_partition = divide(normalized_shape, self.summa_dim)
# create parameters
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
if self.row_rank == 0:
self.gamma = Parameter(torch.ones(
self.partitioned_partition,
**factory_kwargs))
self.beta = Parameter(torch.zeros(
self.partitioned_partition,
**factory_kwargs))
else:
self.gamma = Parameter(torch.tensor(
1.0,
requires_grad=True,
**factory_kwargs))
self.beta = Parameter(torch.tensor(
1.0,
requires_grad=True,
**factory_kwargs))
self._set_tensor_parallel_attributes()
def _set_tensor_parallel_attributes(self):
set_tensor_parallel_attribute(self.gamma)
set_tensor_parallel_attribute(self.beta)
def forward(self, x: Tensor) -> Tensor:
with torch.no_grad():
E_x = torch.sum(x, dim=-1, keepdim=True) # [b/q, s, 1]
torch.distributed.all_reduce(
E_x, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW))
E_x /= self.normalized_shape
# Var_x in the block below is the sum of input^2
Var_x = torch.sum(x * x, dim=-1, keepdim=True) # [b/q, s, 1]
torch.distributed.all_reduce(
Var_x, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW))
Var_x /= self.normalized_shape
Var_x = Var_x - E_x * E_x # variance of x [b/q, s, 1]
# this time 1/sqrt(Var_x + epsilon)
Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon)
output = _LayerNorm_2D.apply(x, E_x, Var_x, self.normalized_shape,
ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL)
bias = Add_Bias_2D.apply(
None, self.beta, self.partitioned_partition,
self.row_rank, self.col_rank,
ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL,
True,
self.data_parallel_rank,
self.pipeline_parallel_rank,
self.pipeline_parallel_size,
self.tensor_parallel_size
)
scale = Add_Bias_2D.apply(
None, self.gamma, self.partitioned_partition,
self.row_rank, self.col_rank,
ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL,
True,
self.data_parallel_rank,
self.pipeline_parallel_rank,
self.pipeline_parallel_size,
self.tensor_parallel_size
)
output = torch.addcmul(bias, scale, output)
return output
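# Single-device check (the identity only, no all-reduces) that the blockwise statistics
# assembled above -- E[x] and Var[x] = E[x^2] - E[x]^2 -- match the usual per-row layer
# norm statistics; shapes are illustrative.
import torch

torch.manual_seed(0)
x = torch.randn(4, 64)
h = x.shape[-1]
E_x = x.sum(dim=-1, keepdim=True) / h
Var_x = (x * x).sum(dim=-1, keepdim=True) / h - E_x * E_x
assert torch.allclose(E_x, x.mean(dim=-1, keepdim=True), atol=1e-6)
assert torch.allclose(Var_x, ((x - E_x) ** 2).mean(dim=-1, keepdim=True), atol=1e-5)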

13
colossalai/nn/layer/parallel_2p5d/__init__.py

@@ -0,0 +1,13 @@
from ._operation import Matmul_AB_2p5D, Matmul_ABT_2p5D, Matmul_ATB_2p5D, Sum_2p5D, Add_Bias_2p5D
from ._transformer import TransformerMLP2p5D, TransformerSelfAttention2p5D, TransformerLayer2p5D
from ._vit import (ViTMLP2p5D, ViTSelfAttention2p5D, ViTHead2p5D, ViTPatchEmbedding2p5D, ViTTokenFuser2p5D,
ViTInputSplitter2p5D)
from .layers import Linear2p5D, LayerNorm2p5D
__all__ = [
'Matmul_AB_2p5D', 'Matmul_ABT_2p5D', 'Matmul_ATB_2p5D', 'Sum_2p5D', 'Add_Bias_2p5D',
'TransformerMLP2p5D', 'TransformerSelfAttention2p5D', 'TransformerLayer2p5D',
'ViTMLP2p5D', 'ViTSelfAttention2p5D', 'ViTHead2p5D', 'ViTPatchEmbedding2p5D', 'ViTTokenFuser2p5D',
'ViTInputSplitter2p5D',
'Linear2p5D', 'LayerNorm2p5D'
]

535
colossalai/nn/layer/parallel_2p5d/_operation.py

@@ -0,0 +1,535 @@
from typing import Any, Tuple
import torch
import torch.distributed as dist
from torch import Tensor
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device, empty_cache
def get_parallel_group(parallel_mode: ParallelMode):
return gpc.get_group(parallel_mode)
def get_global_rank():
return gpc.get_global_rank()
def get_parallel_rank(parallel_mode: ParallelMode):
return gpc.get_local_rank(parallel_mode)
class Matmul_AB_2p5D(torch.autograd.Function):
"""Matrix multiplication for :math:`C = AB`
"""
@staticmethod
def forward(ctx: Any,
A: Tensor,
B: Tensor,
tesseract_dim: int,
tesseract_dep: int,
out_shape: Tuple[int, ...],
row_rank: int,
col_rank: int,
dep_rank: int,
row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode,
dep_parallel_mode: ParallelMode,
data_parallel_rank: int,
pipeline_parallel_rank: int,
pipeline_parallel_size: int,
tensor_parallel_size: int) -> Tensor:
# A: [b / dq, s, h / q] -> [(b * s) / dq, h / q]
# B: [h / dq, s / q]
# C: [b / dq, s, s / q] -> [(b * s) / dq, s / q]
assert A.shape[-1] == B.shape[-2], \
'Invalid shapes: A={}, B={} for AB.'.format(A.shape, B.shape)
empty_cache()
if ctx:
ctx.save_for_backward(A, B)
A_shape = A.shape
A = A.reshape((-1, A_shape[-1]))
B_shape = B.shape
B = B.reshape((-1, B_shape[-1]))
C_shape = (A.shape[0], B.shape[-1])
C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device())
for i in range(tesseract_dim):
A_temp = A.clone()
B_temp = B.clone()
src_a = i + row_rank * tesseract_dim + dep_rank * (
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + pipeline_parallel_rank * tensor_parallel_size
dist.broadcast(A_temp, src=src_a,
group=get_parallel_group(row_parallel_mode))
src_b = col_rank + i * tesseract_dim + dep_rank * (
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + pipeline_parallel_rank * tensor_parallel_size
dist.broadcast(B_temp, src=src_b,
group=get_parallel_group(col_parallel_mode))
torch.addmm(C, A_temp, B_temp, out=C)
out = C.reshape(out_shape)
if ctx:
ctx.tesseract_dim = tesseract_dim
ctx.tesseract_dep = tesseract_dep
ctx.row_rank = row_rank
ctx.col_rank = col_rank
ctx.dep_rank = dep_rank
ctx.row_parallel_mode = row_parallel_mode
ctx.col_parallel_mode = col_parallel_mode
ctx.dep_parallel_mode = dep_parallel_mode
ctx.A_shape = A_shape
ctx.B_shape = B_shape
ctx.data_parallel_rank = data_parallel_rank
ctx.pipeline_parallel_rank = pipeline_parallel_rank
ctx.pipeline_parallel_size = pipeline_parallel_size
ctx.tensor_parallel_size = tensor_parallel_size
return out
@staticmethod
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors
A_grad = Matmul_ABT_2p5D.forward(
None,
output_grad, B,
ctx.tesseract_dim, ctx.tesseract_dep, ctx.A_shape,
ctx.row_rank, ctx.col_rank, ctx.dep_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.dep_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
B_grad = Matmul_ATB_2p5D.forward(
None,
A, output_grad,
ctx.tesseract_dim, ctx.tesseract_dep, ctx.B_shape,
ctx.row_rank, ctx.col_rank, ctx.dep_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.dep_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None
class Matmul_ABT_2p5D(torch.autograd.Function):
"""Matrix multiplication for :math:`C = AB^T`
"""
@staticmethod
def forward(ctx: Any,
A: Tensor,
B: Tensor,
tesseract_dim: int,
tesseract_dep: int,
out_shape: Tuple[int, ...],
row_rank: int,
col_rank: int,
dep_rank: int,
row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode,
dep_parallel_mode: ParallelMode,
data_parallel_rank: int,
pipeline_parallel_rank: int,
pipeline_parallel_size: int,
tensor_parallel_size: int
) -> Tensor:
assert A.shape[-1] == B.shape[-1], \
'Invalid shapes: A={}, B={} for ABT.'.format(A.shape, B.shape)
empty_cache()
if ctx:
ctx.save_for_backward(A, B)
A_shape = A.shape
A = A.reshape((-1, A_shape[-1]))
B_shape = B.shape
B = B.reshape((-1, B_shape[-1]))
C_shape = (A.shape[0], B.shape[0])
C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
for i in range(tesseract_dim):
B_temp = B.clone()
src_b = col_rank + i * tesseract_dim + dep_rank * (
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
dist.broadcast(B_temp, src=src_b, group=gpc.get_group(col_parallel_mode))
C_temp = torch.matmul(A, B_temp.transpose(0, 1))
src_c = i + row_rank * tesseract_dim + dep_rank * (
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
dist.reduce(C_temp, dst=src_c, group=gpc.get_group(row_parallel_mode))
if i == col_rank:
C = C_temp.clone()
out = C.reshape(out_shape)
if ctx:
ctx.tesseract_dim = tesseract_dim
ctx.tesseract_dep = tesseract_dep
ctx.row_rank = row_rank
ctx.col_rank = col_rank
ctx.dep_rank = dep_rank
ctx.row_parallel_mode = row_parallel_mode
ctx.col_parallel_mode = col_parallel_mode
ctx.dep_parallel_mode = dep_parallel_mode
ctx.A_shape = A_shape
ctx.B_shape = B_shape
ctx.data_parallel_rank = data_parallel_rank
ctx.pipeline_parallel_rank = pipeline_parallel_rank
ctx.pipeline_parallel_size = pipeline_parallel_size
ctx.tensor_parallel_size = tensor_parallel_size
return out
@staticmethod
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors
A_grad = Matmul_AB_2p5D.forward(
None,
output_grad, B,
ctx.tesseract_dim, ctx.tesseract_dep, ctx.A_shape,
ctx.row_rank, ctx.col_rank, ctx.dep_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.dep_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
B_grad = Matmul_ATB_2p5D.forward(
None,
output_grad, A,
ctx.tesseract_dim, ctx.tesseract_dep, ctx.B_shape,
ctx.row_rank, ctx.col_rank, ctx.dep_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.dep_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None
class Matmul_ATB_2p5D(torch.autograd.Function):
"""Matrix multiplication for :math:`C = A^TB`
"""
@staticmethod
def forward(ctx: Any,
A: Tensor,
B: Tensor,
tesseract_dim: int,
tesseract_dep: int,
out_shape: Tuple[int, ...],
row_rank: int,
col_rank: int,
dep_rank: int,
row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode,
dep_parallel_mode: ParallelMode,
data_parallel_rank: int,
pipeline_parallel_rank: int,
pipeline_parallel_size: int,
tensor_parallel_size: int):
assert A.shape[-2] == B.shape[-2], \
'Invalid shapes: A={}, B={} for ATB.'.format(A.shape, B.shape)
empty_cache()
if ctx:
ctx.save_for_backward(A, B)
A_shape = A.shape
A = A.reshape((-1, A_shape[-1]))
B_shape = B.shape
B = B.reshape((-1, B_shape[-1]))
C_shape = (A.shape[-1], B.shape[-1])
C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
for i in range(tesseract_dim):
A_temp = A.clone()
src_a = i + row_rank * tesseract_dim + dep_rank * (
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
dist.broadcast(A_temp, src=src_a,
group=get_parallel_group(row_parallel_mode))
C_temp = torch.matmul(A_temp.transpose(0, 1), B)
src_c = col_rank + i * tesseract_dim + dep_rank * (
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
dist.reduce(C_temp, dst=src_c,
group=get_parallel_group(col_parallel_mode))
if i == row_rank:
C = C_temp.clone()
out = C.reshape(out_shape)
if ctx:
ctx.tesseract_dim = tesseract_dim
ctx.tesseract_dep = tesseract_dep
ctx.row_rank = row_rank
ctx.col_rank = col_rank
ctx.dep_rank = dep_rank
ctx.row_parallel_mode = row_parallel_mode
ctx.col_parallel_mode = col_parallel_mode
ctx.dep_parallel_mode = dep_parallel_mode
ctx.A_shape = A_shape
ctx.B_shape = B_shape
ctx.data_parallel_rank = data_parallel_rank
ctx.pipeline_parallel_rank = pipeline_parallel_rank
ctx.pipeline_parallel_size = pipeline_parallel_size
ctx.tensor_parallel_size = tensor_parallel_size
return out
@staticmethod
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors
A_grad = Matmul_ABT_2p5D.forward(
None,
B, output_grad,
ctx.tesseract_dim, ctx.tesseract_dep, ctx.A_shape,
ctx.row_rank, ctx.col_rank, ctx.dep_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.dep_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
B_grad = Matmul_AB_2p5D.forward(
None,
A, output_grad,
ctx.tesseract_dim, ctx.tesseract_dep, ctx.B_shape,
ctx.row_rank, ctx.col_rank, ctx.dep_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.dep_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None
class Add_Bias_2p5D(torch.autograd.Function):
"""Matrix add bias: :math:`C = A + b`
"""
@staticmethod
def forward(ctx: Any,
input: Tensor,
bias: Tensor,
output_size_per_partition: int,
tesseract_dim: int,
tesseract_dep: int,
row_rank: int,
col_rank: int,
dep_rank: int,
row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode,
dep_parallel_mode: ParallelMode,
skip_bias_add: bool,
data_parallel_rank: int,
pipeline_parallel_rank: int,
pipeline_parallel_size: int,
tensor_parallel_size: int
) -> Tensor:
if row_rank == 0:
bias_temp = bias.clone()
else:
bias_temp = torch.zeros(
output_size_per_partition,
dtype=bias.dtype,
device=get_current_device())
src_rank = col_rank + dep_rank * (
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
dist.broadcast(bias_temp, src=src_rank, group=get_parallel_group(col_parallel_mode))
ctx.row_rank = row_rank
ctx.col_rank = col_rank
ctx.dep_rank = dep_rank
ctx.tesseract_dim = tesseract_dim
ctx.tesseract_dep = tesseract_dep
ctx.row_parallel_mode = row_parallel_mode
ctx.col_parallel_mode = col_parallel_mode
ctx.dep_parallel_mode = dep_parallel_mode
ctx.bias = skip_bias_add
ctx.data_parallel_rank = data_parallel_rank
ctx.pipeline_parallel_rank = pipeline_parallel_rank
ctx.pipeline_parallel_size = pipeline_parallel_size
ctx.tensor_parallel_size = tensor_parallel_size
if skip_bias_add:
return bias_temp
else:
output = input + bias_temp
return output
@staticmethod
def backward(ctx, output_grad):
row_rank = ctx.row_rank
col_rank = ctx.col_rank
dep_rank = ctx.dep_rank
tesseract_dim = ctx.tesseract_dim
tesseract_dep = ctx.tesseract_dep
row_parallel_mode = ctx.row_parallel_mode
col_parallel_mode = ctx.col_parallel_mode
dep_parallel_mode = ctx.dep_parallel_mode
data_parallel_rank = ctx.data_parallel_rank
pipeline_parallel_rank = ctx.pipeline_parallel_rank
pipeline_parallel_size = ctx.pipeline_parallel_size
tensor_parallel_size = ctx.tensor_parallel_size
if ctx.bias:
dst_rank = col_rank + dep_rank * (
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
dist.reduce(output_grad, dst=dst_rank, group=get_parallel_group(col_parallel_mode))
if row_rank == 0:
return None, output_grad, None, None, None, None, None, None, None, None, None, None, None, None, None, None
else:
grad_tmp = torch.zeros_like(output_grad)
return None, grad_tmp, None, None, None, None, None, None, None, None, None, None, None, None, None, None
else:
reduce_dim = tuple(range(output_grad.ndim - 1))
reduce = torch.sum(output_grad, dim=reduce_dim)
dst_rank = col_rank + dep_rank * (
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
dist.reduce(reduce, dst=dst_rank, group=get_parallel_group(col_parallel_mode))
if row_rank == 0:
return output_grad, reduce, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
else:
reduce_tmp = torch.zeros_like(reduce)
return output_grad, reduce_tmp, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
class _LayerNorm_2p5D(torch.autograd.Function):
@staticmethod
def forward(ctx: Any,
input: Tensor,
E_x: Tensor,
Var_x: Tensor,
hidden_size: int,
row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode,
dep_parallel_mode: ParallelMode) -> Tensor:
input = input - E_x
# in here, input = x - E[x], Var_x = 1 / sqrt(Var[x] + eps)
ctx.hidden_size = hidden_size
output = input * Var_x
ctx.save_for_backward(output, Var_x)
ctx.row_parallel_mode = row_parallel_mode
ctx.col_parallel_mode = col_parallel_mode
ctx.dep_parallel_mode = dep_parallel_mode
return output
@staticmethod
def backward(ctx, output_grad):
row_parallel_mode = ctx.row_parallel_mode
col_parallel_mode = ctx.col_parallel_mode
dep_parallel_mode = ctx.dep_parallel_mode
x, Var_x = ctx.saved_tensors
# in here, Var_x = 1 / sqrt(Var[x] + eps), x = (x - E[x]) * Var_x
with torch.no_grad():
output_grad_sum = torch.sum(output_grad, dim=-1, keepdim=True)
torch.distributed.all_reduce(
output_grad_sum, group=get_parallel_group(row_parallel_mode))
output_grad_sum /= ctx.hidden_size
output_grad_mul_x_sum = torch.sum(
output_grad * x, dim=-1, keepdim=True)
torch.distributed.all_reduce(
output_grad_mul_x_sum, group=get_parallel_group(row_parallel_mode))
output_grad_mul_x_sum /= ctx.hidden_size
input_grad = output_grad.clone()
input_grad -= x * output_grad_mul_x_sum
input_grad -= output_grad_sum
input_grad *= Var_x
return input_grad, None, None, None, None, None, None
class Sum_2p5D(torch.autograd.Function):
"""Compute the sum of input tensors
"""
@staticmethod
def forward(ctx,
inputs,
dim,
tesseract_dim,
row_parallel_mode,
keepdim=False):
# input: [b/q, s, h/q]
empty_cache()
ctx.save_for_backward(inputs)
# sum: [b/q, s]
out = torch.sum(inputs, dim=dim, keepdim=keepdim)
torch.distributed.all_reduce(
out, group=gpc.get_group(row_parallel_mode))
return out
@staticmethod
def backward(ctx, output_grad):
with torch.no_grad():
            inputs, = ctx.saved_tensors
            input_grad = torch.ones(inputs.shape, dtype=output_grad.dtype,
                                    device=output_grad.device)
        return input_grad, None, None, None, None
class _ViT_Split_2p5D(torch.autograd.Function):
@staticmethod
def forward(ctx, inputs, batch_size,
tesseract_dim, tesseract_dep,
xz_parallel_mode):
# inputs: [b, s, h/q]
# output: [b/dq, s, h/q]
empty_cache()
ctx.batch_size = batch_size
ctx.tesseract_dim = tesseract_dim
ctx.tesseract_dep = tesseract_dep
ctx.xz_parallel_mode = xz_parallel_mode
xz_rank = gpc.get_local_rank(xz_parallel_mode)
output = torch.chunk(inputs, tesseract_dep *
tesseract_dim, dim=0)[xz_rank]
output = output.clone()
return output
@staticmethod
def backward(ctx, output_grad):
# output_grad: [b/dq, s, h/q]
# grads: [b, s, h/q]
# *
grads_shape = (ctx.batch_size,) + output_grad.shape[1:]
grads = torch.empty(grads_shape,
dtype=output_grad.dtype,
device=get_current_device())
dist.all_gather(list(grads.chunk(ctx.tesseract_dim * ctx.tesseract_dep, dim=0)),
output_grad.contiguous(),
group=get_parallel_group(ctx.xz_parallel_mode))
return grads, None, None, None, None

206
colossalai/nn/layer/parallel_2p5d/_transformer.py

@@ -0,0 +1,206 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math
import torch
from torch import nn as nn, Tensor
from colossalai.nn.layer._common_utils import divide
from colossalai.registry import LAYERS
from ._utils import assert_tesseract_initialization, \
get_tesseract_dim_dep_from_env
from .layers import Linear2p5D, LayerNorm2p5D
from .._common_utils import ACT2FN
@LAYERS.register_module
class TransformerMLP2p5D(nn.Module):
"""
    The MLP takes an input of hidden size h, projects it to mlp_ratio * h,
    applies a nonlinear activation, and projects back to hidden size h.
    Dropout, a residual connection and layer normalization are applied at
    the end.
    :param in_features: the size of input tensor
    :type in_features: int
    :param mlp_ratio: hidden size of MLP divided by embedding dim
    :type mlp_ratio: int
:param act_func: activation function, defaults to 'gelu'
:type act_func: str, optional
:param dropout_prob: dropout probability, defaults to 0.
:type dropout_prob: float, optional
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self,
in_features: int,
mlp_ratio: int,
act_func: str = 'gelu',
dropout_prob: float = 0.,
dtype=None,
):
super().__init__()
assert_tesseract_initialization()
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
self.in_features = in_features
# Project to h * mlp_ratio.
self.dense_1 = Linear2p5D(
in_features,
mlp_ratio * in_features,
dtype=dtype
)
assert act_func in ACT2FN.keys(), f'Invalid value for argument act_func, ' \
f'activation function can only be {list(ACT2FN.keys())}'
self.activation_func = ACT2FN[act_func]
# Project back to h.
self.dense_2 = Linear2p5D(
mlp_ratio * in_features,
in_features,
dtype=dtype
)
self.dropout = nn.Dropout(dropout_prob)
self.layernorm = LayerNorm2p5D(in_features, dtype=dtype)
def forward(self, x: Tensor) -> Tensor:
intermediate_output = self.dense_1(x)
intermediate_output = self.activation_func(intermediate_output)
output = self.dense_2(intermediate_output)
output = self.dropout(output)
output = self.layernorm(x + output)
return output
@LAYERS.register_module
class TransformerSelfAttention2p5D(nn.Module):
"""Self attention layer for 2.5D parallel Transformer
:param hidden_size: hidden size
:type hidden_size: int
:param num_attention_heads: number of attention heads
:type num_attention_heads: int
:param attention_dropout_prob: dropout probability for attention layer
:type attention_dropout_prob: float
:param hidden_dropout_prob: dropout probability for hidden layer
:type hidden_dropout_prob: float
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self,
hidden_size,
num_attention_heads,
attention_dropout_prob,
hidden_dropout_prob,
dtype=None,
):
super().__init__()
assert_tesseract_initialization()
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
self.hidden_size = hidden_size
self.num_attention_heads = divide(
num_attention_heads, self.tesseract_dim) # *
self.attention_head_size = divide(hidden_size, num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query_key_value = Linear2p5D(
hidden_size,
3 * hidden_size,
dtype=dtype,
)
self.attention_dropout = nn.Dropout(attention_dropout_prob)
self.dense = Linear2p5D(
hidden_size,
hidden_size,
dtype=dtype,
)
self.dropout = nn.Dropout(hidden_dropout_prob)
self.layernorm = LayerNorm2p5D(
hidden_size,
dtype=dtype)
def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
query_key_value = self.query_key_value(hidden_states)
new_qkv_shape = query_key_value.shape[:-1] + \
(self.num_attention_heads, 3 * self.attention_head_size)
query_key_value = query_key_value.view(new_qkv_shape)
query_key_value = query_key_value.permute((0, 2, 1, 3))
query_layer, key_layer, value_layer = torch.chunk(
query_key_value, 3, dim=-1)
attention_scores = torch.matmul(
query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / \
math.sqrt(self.attention_head_size)
attention_scores = attention_scores + attention_mask
attention_probs = nn.Softmax(dim=-1)(attention_scores)
attention_probs = self.attention_dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute((0, 2, 1, 3)).contiguous()
new_context_layer_shape = context_layer.size()[
:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
output = self.dense(context_layer)
output = self.dropout(output)
attention_output = self.layernorm(hidden_states + output)
return attention_output
@LAYERS.register_module
class TransformerLayer2p5D(nn.Module):
"""Transformer layer which contains a self-attention layer and a MLP layer
:param hidden_size: hidden size
:type hidden_size: int
:param num_attention_heads: number of attention heads
:type num_attention_heads: int
:param act_func: activation function, defaults to 'gelu'
:type act_func: str, optional
:param mlp_ratio: hidden size of MLP divided by embedding dim, defaults to 4.0
:type mlp_ratio: float, optional
:param attention_dropout_prob: dropout probability for attention layer, defaults to 0.
:type attention_dropout_prob: float, optional
    :param hidden_dropout_prob: dropout probability for hidden layers, defaults to 0.
:type hidden_dropout_prob: float, optional
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self,
hidden_size,
num_attention_heads,
act_func='gelu',
mlp_ratio=4,
attention_dropout_prob: float = 0.,
hidden_dropout_prob: float = 0.,
dtype=None,
):
super().__init__()
self.attention = TransformerSelfAttention2p5D(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
attention_dropout_prob=attention_dropout_prob,
hidden_dropout_prob=hidden_dropout_prob,
dtype=dtype,
)
self.mlp = TransformerMLP2p5D(
in_features=hidden_size,
dropout_prob=hidden_dropout_prob,
act_func=act_func,
mlp_ratio=mlp_ratio,
dtype=dtype,
)
def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
attention_output = self.attention(hidden_states, attention_mask)
output = self.mlp(attention_output)
return output
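
A minimal usage sketch for the layer above, assuming the 2.5D tesseract process groups have already been set up (for example via colossalai.initialize with a 2.5D tensor-parallel configuration) and TESSERACT_DIM = q = 2; the module path follows this diff, and the local shapes are assumptions based on the shape comments in _operation.py:

    import torch
    from colossalai.utils import get_current_device
    from colossalai.nn.layer.parallel_2p5d._transformer import TransformerLayer2p5D

    layer = TransformerLayer2p5D(hidden_size=768, num_attention_heads=12)
    hidden_states = torch.randn(4, 128, 768 // 2, device=get_current_device())  # local shard [b/dq, s, h/q], q = 2
    attention_mask = torch.zeros(4, 1, 1, 128, device=get_current_device())     # additive mask, broadcast over heads
    output = layer(hidden_states, attention_mask)                               # same local shape as hidden_states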

25
colossalai/nn/layer/parallel_2p5d/_utils.py

@@ -0,0 +1,25 @@
import os
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
def get_tesseract_dim_dep_from_env():
try:
tesseract_dim = int(os.environ['TESSERACT_DIM'])
tesseract_dep = int(os.environ['TESSERACT_DEP'])
assert tesseract_dim > 0, 'TESSERACT_DIM must be larger than zero'
assert tesseract_dep > 0, 'TESSERACT_DEP must be larger than zero'
return tesseract_dim, tesseract_dep
except KeyError as e:
raise EnvironmentError('TESSERACT_DIM or TESSERACT_DEP is not found in the current environment, '
'please make sure that you have used the correct process group initializer')
def assert_tesseract_initialization():
assert gpc.is_initialized(ParallelMode.PARALLEL_2P5D_COL) and \
gpc.is_initialized(ParallelMode.PARALLEL_2P5D_ROW) and \
gpc.is_initialized(ParallelMode.PARALLEL_2P5D_DEP) and \
gpc.is_initialized(ParallelMode.PARALLEL_2P5D_XZ), \
        'All of PARALLEL_2P5D_COL, PARALLEL_2P5D_ROW, PARALLEL_2P5D_DEP and PARALLEL_2P5D_XZ must be initialized by the process group initializer'
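
The helper above only reads two plain environment variables; a quick illustration of that contract (the process group initializer normally exports these, setting them by hand is only for demonstration):

    import os
    from colossalai.nn.layer.parallel_2p5d._utils import get_tesseract_dim_dep_from_env

    os.environ['TESSERACT_DIM'] = '2'
    os.environ['TESSERACT_DEP'] = '1'
    assert get_tesseract_dim_dep_from_env() == (2, 1)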

351
colossalai/nn/layer/parallel_2p5d/_vit.py

@@ -0,0 +1,351 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math
import torch
from torch import nn as nn, Tensor, distributed as dist
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer.vanilla_vision_transformer.layers import to_2tuple
from colossalai.registry import LAYERS
from colossalai.utils import get_current_device
from ._operation import _ViT_Split_2p5D
from ._utils import assert_tesseract_initialization, \
get_tesseract_dim_dep_from_env
from .layers import Linear2p5D
from .._common_utils import ACT2FN, divide, CheckpointModule
from .._common_utils import set_tensor_parallel_attribute
@LAYERS.register_module
class ViTMLP2p5D(CheckpointModule):
"""MLP layer for 2.5D parallel Vision Transformer
:param in_features: size of each input sample
:type in_features: int
:param mlp_ratio: hidden size of MLP divided by embedding dim
:type mlp_ratio: int
:param act_func: activation function, defaults to 'gelu'
:type act_func: str, optional
:param dropout_prob: dropout probability, defaults to 0.
:type dropout_prob: float, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param checkpoint: If set to `True`, activation checkpoint is used, defaults to `False`
:type checkpoint: bool, optional
"""
def __init__(self,
in_features: int,
mlp_ratio: int,
act_func: str = 'gelu',
dropout_prob: float = 0.,
dtype=None,
checkpoint: bool = False
):
super().__init__(checkpoint=checkpoint)
assert_tesseract_initialization()
self.in_features = in_features
self.mlp_ratio = mlp_ratio
# Project to mlp_ratio * h.
self.dense_1 = Linear2p5D(
self.in_features,
self.mlp_ratio * self.in_features,
dtype=dtype,
)
self.act = ACT2FN[act_func]
# Project back to h.
self.dense_2 = Linear2p5D(
self.mlp_ratio * self.in_features,
self.in_features,
dtype=dtype,
)
self.dropout = nn.Dropout(dropout_prob)
def _forward(self, hidden_states: Tensor) -> Tensor:
intermediate_output = self.dense_1(hidden_states)
intermediate_output = self.act(intermediate_output)
intermediate_output = self.dropout(intermediate_output)
output = self.dense_2(intermediate_output)
output = self.dropout(output)
return output
@LAYERS.register_module
class ViTSelfAttention2p5D(CheckpointModule):
"""Self-attention layer for 2.5D parallel Vision Transformer
:param hidden_size: hidden size
:type hidden_size: int
:param num_attention_heads: number of attention heads
:type num_attention_heads: int
:param attention_dropout_prob: dropout probability for attention layers
:type attention_dropout_prob: float
:param hidden_dropout_prob: dropout probability for hidden layers
:type hidden_dropout_prob: float
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param checkpoint: If set to `True`, activation checkpoint is used, defaults to `False`
:type checkpoint: bool, optional
"""
def __init__(self,
hidden_size,
num_attention_heads,
attention_dropout_prob,
hidden_dropout_prob,
dtype=None,
checkpoint: bool = False
):
super().__init__(checkpoint=checkpoint)
assert_tesseract_initialization()
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
self.hidden_size = hidden_size
self.num_attention_heads = divide(
num_attention_heads, self.tesseract_dim) # *
self.attention_head_size = divide(hidden_size, num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query_key_value = Linear2p5D(
hidden_size,
3 * hidden_size,
dtype=dtype,
)
self.attention_dropout = nn.Dropout(attention_dropout_prob)
self.dense = Linear2p5D(
hidden_size,
hidden_size,
dtype=dtype,
)
self.dropout = nn.Dropout(hidden_dropout_prob)
def _forward(self, hidden_states: Tensor) -> Tensor:
query_key_value = self.query_key_value(hidden_states)
new_qkv_shape = query_key_value.shape[:-1] + \
(self.num_attention_heads, 3 * self.attention_head_size)
query_key_value = query_key_value.view(new_qkv_shape)
query_key_value = query_key_value.permute((0, 2, 1, 3))
query_layer, key_layer, value_layer = torch.chunk(
query_key_value, 3, dim=-1)
attention_scores = torch.matmul(
query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / \
math.sqrt(self.attention_head_size)
attention_probs = nn.Softmax(dim=-1)(attention_scores)
attention_probs = self.attention_dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.transpose(1, 2)
new_context_layer_shape = context_layer.size()[
:-2] + (self.all_head_size,)
context_layer = context_layer.reshape(new_context_layer_shape)
output = self.dense(context_layer)
output = self.dropout(output)
return output
@LAYERS.register_module
class ViTHead2p5D(nn.Module):
"""Output layer for 2.5D parallel Vision Transformer
:param hidden_size: hidden size
:type hidden_size: int
:param num_classes: number of classes
:type num_classes: int
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self,
hidden_size,
num_classes,
dtype=None,
):
super().__init__()
assert_tesseract_initialization()
self.linear = Linear2p5D(
hidden_size,
num_classes,
dtype=dtype,
)
def forward(self, x: Tensor) -> Tensor:
x = x[:, 0]
x = self.linear(x)
return x
@LAYERS.register_module
class ViTPatchEmbedding2p5D(nn.Module):
""" 2.5D Image to Patch Embedding
    :param img_size: image size
:type img_size: int
:param patch_size: patch size
:type patch_size: int
:param embed_dim: dimension of embedding
:type embed_dim: int
:param in_chans: number of channels of input image, defaults to 3
:type in_chans: int, optional
:param flatten: whether to flatten output tensor, defaults to True
:type flatten: bool, optional
"""
def __init__(self,
img_size,
patch_size,
embed_dim,
in_chans=3,
flatten=True):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
assert_tesseract_initialization()
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0],
img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
self.embed_dim = embed_dim // self.tesseract_dim # *
self.proj = nn.Conv2d(in_chans,
self.embed_dim,
kernel_size=patch_size,
stride=patch_size,
)
# move self to cuda before sync
self.to(get_current_device())
# sync
self._broadcast_conv_params()
self.proj.weight.register_hook(self._sync_grad_during_backward)
self.proj.bias.register_hook(self._sync_grad_during_backward)
def _broadcast_conv_params(self) -> None:
xz_rank = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2P5D_XZ)
dist.broadcast(self.proj.weight, src=xz_rank[0],
group=gpc.get_group(ParallelMode.PARALLEL_2P5D_XZ))
dist.broadcast(self.proj.bias, src=xz_rank[0],
group=gpc.get_group(ParallelMode.PARALLEL_2P5D_XZ))
    def _sync_grad_during_backward(self, grad: Tensor) -> Tensor:
dist.all_reduce(grad, group=gpc.get_group(
ParallelMode.PARALLEL_2P5D_XZ))
grad = grad / self.tesseract_dim / self.tesseract_dep # *
return grad
def forward(self, x: Tensor) -> Tensor:
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
return x
@LAYERS.register_module
class ViTTokenFuser2p5D(nn.Module):
"""
    Fuse the cls token and position embedding into the input
:param img_size: image size
:type img_size: int
:param patch_size: patch size
:type patch_size: int
:param embed_dim: dimension of embedding
:type embed_dim: int
:param drop_rate: dropout probability, defaults to 0.
:type drop_rate: float, optional
"""
def __init__(self,
img_size,
patch_size,
embed_dim,
drop_rate=0.
):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
assert_tesseract_initialization()
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0],
img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.embed_dim = embed_dim
self.cls_token = nn.Parameter(torch.zeros(
1, 1, self.embed_dim // self.tesseract_dim)) # *
self.pos_embed = nn.Parameter(torch.zeros(
1, self.num_patches + 1, self.embed_dim // self.tesseract_dim)) # *
# move to cuda before broadcast
self.to(get_current_device())
self._broadcast_params()
self.cls_token.register_hook(self._sync_grad_hook)
self.pos_embed.register_hook(self._sync_grad_hook)
self.pos_drop = nn.Dropout(p=drop_rate)
self._set_tensor_parallel_attribute()
def _set_tensor_parallel_attribute(self):
set_tensor_parallel_attribute(self.cls_token)
set_tensor_parallel_attribute(self.pos_embed)
def _broadcast_params(self) -> None:
" broadcast to all column ranks for data consistency "
xz_rank = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2P5D_XZ)
dist.broadcast(self.cls_token, src=xz_rank[0],
group=gpc.get_group(ParallelMode.PARALLEL_2P5D_XZ))
dist.broadcast(self.pos_embed, src=xz_rank[0],
group=gpc.get_group(ParallelMode.PARALLEL_2P5D_XZ))
    def _sync_grad_hook(self, grad) -> Tensor:
dist.all_reduce(grad, group=gpc.get_group(
ParallelMode.PARALLEL_2P5D_XZ))
grad = grad / self.tesseract_dim / self.tesseract_dep # *
return grad
def forward(self, x: Tensor) -> Tensor:
# stole cls_tokens impl from Phil Wang, thanks
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_token, x), dim=1)
x = self.pos_drop(x + self.pos_embed)
return x
@LAYERS.register_module
class ViTInputSplitter2p5D(nn.Module):
def __init__(self):
super().__init__()
assert_tesseract_initialization()
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
def forward(self, x: Tensor) -> Tensor:
batch_size = x.size(0)
return _ViT_Split_2p5D.apply(
x,
batch_size,
self.tesseract_dim,
self.tesseract_dep,
ParallelMode.PARALLEL_2P5D_XZ,
)
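
Taken together, the stem layers above turn an image batch into token shards; a shape-flow sketch assuming an initialized tesseract context with TESSERACT_DIM = q = 2 and TESSERACT_DEP = d = 1 (module path taken from this diff):

    import torch
    from colossalai.utils import get_current_device
    from colossalai.nn.layer.parallel_2p5d._vit import (
        ViTPatchEmbedding2p5D, ViTTokenFuser2p5D, ViTInputSplitter2p5D)

    embed = ViTPatchEmbedding2p5D(img_size=224, patch_size=16, embed_dim=768)
    fuser = ViTTokenFuser2p5D(img_size=224, patch_size=16, embed_dim=768)
    split = ViTInputSplitter2p5D()

    img = torch.randn(32, 3, 224, 224, device=get_current_device())
    x = embed(img)   # [b, s, h/q]           patch projection sharded over columns
    x = fuser(x)     # [b, 1 + s, h/q]       cls token and position embedding added
    x = split(x)     # [b/(d*q), 1 + s, h/q] batch scattered across the XZ group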

266
colossalai/nn/layer/parallel_2p5d/layers.py

@@ -0,0 +1,266 @@
import math
import torch
from torch import Tensor
from torch.nn import Parameter, init as init
from colossalai.context import seed, ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import LAYERS
from colossalai.utils import get_current_device
from ._operation import Matmul_AB_2p5D, Add_Bias_2p5D, _LayerNorm_2p5D
from ._utils import get_tesseract_dim_dep_from_env, assert_tesseract_initialization
from .._common_utils import divide, set_tensor_parallel_attribute
from ..base_layer import ParallelLayer
@LAYERS.register_module
class Linear2p5D(ParallelLayer):
"""Linear layer for 2.5D parallelism
:param in_features: size of each input sample
:type in_features: int
:param out_features: size of each output sample
:type out_features: int
:param bias: If set to ``False``, the layer will not learn an additive bias, defaults to True
:type bias: bool, optional
:param dtype: The dtype of parameters, defaults to None
    :type dtype: torch.dtype, optional
    :param skip_bias_add: If set to ``True``, the bias is returned alongside the output instead of being added inside forward (so the add can be fused later), defaults to False
    :type skip_bias_add: bool, optional
    """
def __init__(self,
in_features: int,
out_features: int,
bias: bool = True,
dtype=None,
skip_bias_add: bool = False
):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.skip_bias_add = skip_bias_add
# parallel setting
assert_tesseract_initialization()
self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
self.dep_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
# partitioning dimension
self.input_size_per_partition = divide(in_features, self.tesseract_dim)
self.hidden_size_per_partition = divide(
out_features, self.tesseract_dim)
# create weight, shape: [k/q, h/q]
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
self.weight = Parameter(torch.empty(
self.input_size_per_partition,
self.hidden_size_per_partition,
**factory_kwargs))
# create bias, shape: [h/q]
if bias:
self.bias = Parameter(torch.empty(
self.hidden_size_per_partition,
**factory_kwargs))
else:
self.register_parameter('bias', None)
# initialize parameters
self.reset_parameters()
self._set_tensor_parallel_attributes()
def _set_tensor_parallel_attributes(self):
set_tensor_parallel_attribute(self.weight)
if self.bias is not None:
set_tensor_parallel_attribute(self.bias)
def reset_parameters(self) -> None:
# setting
fan_in = self.in_features
a = math.sqrt(5)
nonlinearity = 'leaky_relu'
# init weight
std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
bound = math.sqrt(3.0) * std
with seed(ParallelMode.TENSOR):
init.uniform_(self.weight, -bound, bound)
# init bias
if self.bias is not None:
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
with seed(ParallelMode.TENSOR):
init.uniform_(self.bias, -bound, bound)
def forward(self, x: Tensor) -> Tensor:
# input: [m/dq, n/q, k/q]
# output: [m/dq, n/q, h/q]
out_shape = x.shape[:-1] + (self.hidden_size_per_partition,)
output = Matmul_AB_2p5D.apply(
x,
self.weight,
self.tesseract_dim,
self.tesseract_dep,
out_shape,
self.row_rank, self.col_rank, self.dep_rank,
ParallelMode.PARALLEL_2P5D_ROW,
ParallelMode.PARALLEL_2P5D_COL,
ParallelMode.PARALLEL_2P5D_DEP,
self.data_parallel_rank,
self.pipeline_parallel_rank,
self.pipeline_parallel_size,
self.tensor_parallel_size,
)
if self.bias is not None:
if self.skip_bias_add:
bias = Add_Bias_2p5D.apply(
None,
self.bias,
self.hidden_size_per_partition,
self.tesseract_dim, self.tesseract_dep,
self.row_rank, self.col_rank, self.dep_rank,
ParallelMode.PARALLEL_2P5D_ROW,
ParallelMode.PARALLEL_2P5D_COL,
ParallelMode.PARALLEL_2P5D_DEP,
True,
self.data_parallel_rank,
self.pipeline_parallel_rank,
self.pipeline_parallel_size,
self.tensor_parallel_size
)
return output, bias
else:
output = Add_Bias_2p5D.apply(
output,
self.bias,
self.hidden_size_per_partition,
self.tesseract_dim, self.tesseract_dep,
self.row_rank, self.col_rank, self.dep_rank,
ParallelMode.PARALLEL_2P5D_ROW,
ParallelMode.PARALLEL_2P5D_COL,
ParallelMode.PARALLEL_2P5D_DEP,
False,
self.data_parallel_rank,
self.pipeline_parallel_rank,
self.pipeline_parallel_size,
self.tensor_parallel_size
)
return output
else:
return output
@LAYERS.register_module
class LayerNorm2p5D(ParallelLayer):
r"""Layer Normalization for 2.5D parallelism
:param normalized_shape: input shape from an expected input
of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
:type normalized_shape: int
:param eps: a value added to the denominator for numerical stability, defaults to 1e-05
:type eps: float, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self,
normalized_shape: int,
eps: float = 1e-05,
dtype=None
):
super().__init__()
# layer norm config
self.normalized_shape = normalized_shape
self.variance_epsilon = eps
# parallel setting
assert_tesseract_initialization()
self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
self.dep_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
# partitioning dimension
self.partitioned_partition = divide(
normalized_shape, self.tesseract_dim) # *
# create parameters
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
if self.row_rank == 0:
self.gamma = Parameter(torch.ones(
self.partitioned_partition,
**factory_kwargs))
self.beta = Parameter(torch.zeros(
self.partitioned_partition,
**factory_kwargs))
else:
self.gamma = Parameter(torch.tensor(
1.0,
requires_grad=True,
**factory_kwargs))
self.beta = Parameter(torch.tensor(
1.0,
requires_grad=True,
**factory_kwargs))
self._set_tensor_parallel_attribute()
def _set_tensor_parallel_attribute(self):
set_tensor_parallel_attribute(self.gamma)
set_tensor_parallel_attribute(self.beta)
def forward(self, x: Tensor) -> Tensor:
with torch.no_grad():
E_x = torch.sum(x, dim=-1, keepdim=True) # [b/q, s, 1]
torch.distributed.all_reduce(
E_x, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW))
E_x /= self.normalized_shape
# Var_x in the block below is the sum of input^2
Var_x = torch.sum(x * x, dim=-1, keepdim=True) # [b/q, s, 1]
torch.distributed.all_reduce(
Var_x, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW))
Var_x /= self.normalized_shape
Var_x = Var_x - E_x * E_x # variance of x [b/q, s, 1]
# this time 1/sqrt(Var_x + epsilon)
Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon)
output = _LayerNorm_2p5D.apply(x, E_x, Var_x, self.normalized_shape,
ParallelMode.PARALLEL_2P5D_ROW,
ParallelMode.PARALLEL_2P5D_COL,
ParallelMode.PARALLEL_2P5D_DEP)
bias = Add_Bias_2p5D.apply(
None, self.beta, self.partitioned_partition,
self.tesseract_dim, self.tesseract_dep,
self.row_rank, self.col_rank, self.dep_rank,
ParallelMode.PARALLEL_2P5D_ROW,
ParallelMode.PARALLEL_2P5D_COL,
ParallelMode.PARALLEL_2P5D_DEP,
True,
self.data_parallel_rank,
self.pipeline_parallel_rank,
self.pipeline_parallel_size,
self.tensor_parallel_size
)
scale = Add_Bias_2p5D.apply(
None, self.gamma, self.partitioned_partition,
self.tesseract_dim, self.tesseract_dep,
self.row_rank, self.col_rank, self.dep_rank,
ParallelMode.PARALLEL_2P5D_ROW,
ParallelMode.PARALLEL_2P5D_COL,
ParallelMode.PARALLEL_2P5D_DEP,
True,
self.data_parallel_rank,
self.pipeline_parallel_rank,
self.pipeline_parallel_size,
self.tensor_parallel_size
)
output = torch.addcmul(bias, scale, output)
return output
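
A sketch of the two bias modes of Linear2p5D, again assuming an initialized 2.5D context with q = 2; the shard shapes follow the comments in forward:

    import torch
    from colossalai.utils import get_current_device
    from colossalai.nn.layer.parallel_2p5d.layers import Linear2p5D

    x = torch.randn(8, 64, 1024 // 2, device=get_current_device())   # local [m/dq, n/q, k/q]

    fused = Linear2p5D(1024, 4096)                 # bias added inside forward
    y = fused(x)                                   # local [m/dq, n/q, h/q]

    deferred = Linear2p5D(1024, 4096, skip_bias_add=True)
    y, bias = deferred(x)                          # bias returned separately for later fusion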

9
colossalai/nn/layer/parallel_3d/__init__.py

@@ -0,0 +1,9 @@
from ._operation import Matmul_ABT_3D, Matmul_ATB_3D, Matmul_AB_3D, Mul_3D, Sum_3D, Add_3D, Reduce_3D
from ._vit import ViTHead3D, ViTMLP3D, ViTPatchEmbedding3D, ViTSelfAttention3D
from .layers import Linear3D, LayerNorm3D
__all__ = [
'Matmul_ABT_3D', 'Matmul_ATB_3D', 'Matmul_AB_3D', 'Mul_3D', 'Sum_3D', 'Add_3D', 'Reduce_3D',
'ViTHead3D', 'ViTMLP3D', 'ViTPatchEmbedding3D', 'ViTSelfAttention3D',
'Linear3D', 'LayerNorm3D'
]

349
colossalai/nn/layer/parallel_3d/_operation.py

@@ -0,0 +1,349 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Any, Tuple
import torch
import torch.distributed as dist
from colossalai.communication import all_gather, reduce_scatter, scatter
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import empty_cache, get_current_device
from torch import Tensor
class Matmul_AB_3D(torch.autograd.Function):
"""Matrix multiplication for :math:`C = AB`
"""
@staticmethod
def forward(ctx: Any,
A: Tensor,
B: Tensor,
depth: int,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
input_dim: int = 0,
weight_dim: int = -1,
output_dim: int = 0) -> Tensor:
# A: [m/q^2, n, k/q]
# B: [k/q, h/q^2]
# C: [m/q^2, n, h/q]
empty_cache()
ctx.save_for_backward(A, B)
assert A.shape[-1] == B.shape[0], \
'Invalid shapes: A={}, B={}.'.format(A.shape, B.shape)
A_temp = all_gather(A, input_dim, input_parallel_mode)
B_temp = all_gather(B, weight_dim, weight_parallel_mode)
C = torch.matmul(A_temp, B_temp)
out = reduce_scatter(C, output_dim, output_parallel_mode)
ctx.depth = depth
ctx.A_group_parallel_mode = input_parallel_mode
ctx.B_group_parallel_mode = weight_parallel_mode
ctx.C_group_parallel_mode = output_parallel_mode
ctx.A_dim = input_dim
ctx.B_dim = weight_dim
ctx.C_dim = output_dim
return out
@staticmethod
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors
with torch.no_grad():
A_grad = Matmul_ABT_3D.apply(output_grad, B, ctx.depth,
ctx.C_group_parallel_mode,
ctx.B_group_parallel_mode,
ctx.A_group_parallel_mode, ctx.C_dim,
ctx.B_dim, ctx.A_dim)
B_grad = Matmul_ATB_3D.apply(A, output_grad, ctx.depth,
ctx.A_group_parallel_mode,
ctx.C_group_parallel_mode,
ctx.B_group_parallel_mode, ctx.A_dim,
ctx.C_dim, ctx.B_dim)
return A_grad, B_grad, None, None, None, None, None, None, None
class Matmul_ABT_3D(torch.autograd.Function):
"""Matrix multiplication for :math:`C = AB^T`
"""
@staticmethod
def forward(ctx: Any,
A: Tensor,
B: Tensor,
depth: int,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
input_dim: int = 0,
weight_dim: int = -1,
output_dim: int = 0) -> Tensor:
# A: [m/q^2, n, h/q]
# B: [k/q, h/q^2]
# C: [m/q^2, n, k/q]
empty_cache()
ctx.save_for_backward(A, B)
A_temp = all_gather(A, input_dim, input_parallel_mode)
B_temp = all_gather(B, weight_dim, weight_parallel_mode)
C = torch.matmul(A_temp, B_temp.transpose(0, 1))
out = reduce_scatter(C, output_dim, output_parallel_mode)
ctx.depth = depth
ctx.A_group_parallel_mode = input_parallel_mode
ctx.B_group_parallel_mode = weight_parallel_mode
ctx.C_group_parallel_mode = output_parallel_mode
ctx.A_dim = input_dim
ctx.B_dim = weight_dim
ctx.C_dim = output_dim
return out
@staticmethod
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors
with torch.no_grad():
A_grad = Matmul_AB_3D.apply(output_grad, B, ctx.depth,
ctx.C_group_parallel_mode,
ctx.B_group_parallel_mode,
ctx.A_group_parallel_mode, ctx.C_dim,
ctx.B_dim, ctx.A_dim)
B_grad = Matmul_ATB_3D.apply(output_grad, A, ctx.depth,
ctx.C_group_parallel_mode,
ctx.A_group_parallel_mode,
ctx.B_group_parallel_mode, ctx.C_dim,
ctx.A_dim, ctx.B_dim)
return A_grad, B_grad, None, None, None, None, None, None, None
class Matmul_ATB_3D(torch.autograd.Function):
"""Matrix multiplication for :math:`C = A^TB`
"""
@staticmethod
def forward(ctx: Any,
A: Tensor,
B: Tensor,
depth: int,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
input_dim: int = 0,
weight_dim: int = 0,
output_dim: int = -1) -> Tensor:
# A: [m/q^2, n, k/q]
# B: [m/q^2, n, h/q]
# C: [k/q, h/q^2]
empty_cache()
ctx.save_for_backward(A, B)
A_temp = all_gather(A, input_dim, input_parallel_mode)
A_temp = A_temp.reshape(-1, A.shape[-1])
B_temp = all_gather(B, weight_dim, weight_parallel_mode)
B_temp = B_temp.reshape(-1, B.shape[-1])
C = torch.matmul(A_temp.transpose(0, 1), B_temp)
out = reduce_scatter(C, output_dim, output_parallel_mode)
ctx.depth = depth
ctx.A_group_parallel_mode = input_parallel_mode
ctx.B_group_parallel_mode = weight_parallel_mode
ctx.C_group_parallel_mode = output_parallel_mode
ctx.A_dim = input_dim
ctx.B_dim = weight_dim
ctx.C_dim = output_dim
return out
@staticmethod
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors
with torch.no_grad():
A_grad = Matmul_ABT_3D.apply(B, output_grad, ctx.depth,
ctx.B_group_parallel_mode,
ctx.C_group_parallel_mode,
ctx.A_group_parallel_mode, ctx.B_dim,
ctx.C_dim, ctx.A_dim)
B_grad = Matmul_AB_3D.apply(A, output_grad, ctx.depth,
ctx.A_group_parallel_mode,
ctx.C_group_parallel_mode,
ctx.B_group_parallel_mode, ctx.A_dim,
ctx.C_dim, ctx.B_dim)
return A_grad, B_grad, None, None, None, None, None, None, None
class Add_3D(torch.autograd.Function):
"""Matrix add bias: :math:`C = A + b`
"""
@staticmethod
def forward(ctx: Any, input_: Tensor, bias: Tensor, depth: int,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode) -> Tensor:
# input: [m/q^2, n, h/q]
# bias: [h/q^2]
ranks_in_group = gpc.get_ranks_in_group(input_parallel_mode)
src_rank = ranks_in_group[gpc.get_local_rank(output_parallel_mode)]
bias_temp = bias.clone()
dist.broadcast(bias_temp,
src=src_rank,
group=gpc.get_group(input_parallel_mode))
# [h/q]
bias_temp = all_gather(bias_temp, -1, weight_parallel_mode)
out = input_ + bias_temp
ctx.depth = depth
ctx.src_rank = src_rank
ctx.A_group_parallel_mode = input_parallel_mode
ctx.B_group_parallel_mode = weight_parallel_mode
ctx.C_group_parallel_mode = output_parallel_mode
return out
@staticmethod
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
# output_grad: [m/q^2, n, h/q]
with torch.no_grad():
# [h/q]
grad = torch.sum(output_grad,
dim=tuple(range(len(output_grad.shape))[:-1]))
bias_grad = reduce_scatter(grad, -1, ctx.B_group_parallel_mode)
dist.reduce(bias_grad,
dst=ctx.src_rank,
group=gpc.get_group(ctx.A_group_parallel_mode))
if gpc.get_local_rank(
ctx.A_group_parallel_mode) != gpc.get_local_rank(
ctx.C_group_parallel_mode):
bias_grad = None
return output_grad, bias_grad, None, None, None, None
class Mul_3D(torch.autograd.Function):
"""Matrix multiplication for :math:`C = A * b`
"""
@staticmethod
def forward(ctx: Any, input_: Tensor, bias: Tensor, depth: int,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode) -> Tensor:
# input: [m/q^2, n, h/q]
# bias: [h/q^2]
ranks_in_group = gpc.get_ranks_in_group(input_parallel_mode)
src_rank = ranks_in_group[gpc.get_local_rank(output_parallel_mode)]
# [h/q^2]
bias_temp = bias.clone()
dist.broadcast(bias_temp,
src=src_rank,
group=gpc.get_group(input_parallel_mode))
# [h/q]
bias_temp = all_gather(bias_temp, -1, weight_parallel_mode)
empty_cache()
ctx.save_for_backward(input_, bias_temp)
out = torch.mul(input_, bias_temp)
ctx.depth = depth
ctx.src_rank = src_rank
ctx.A_group_parallel_mode = input_parallel_mode
ctx.B_group_parallel_mode = weight_parallel_mode
ctx.C_group_parallel_mode = output_parallel_mode
return out
@staticmethod
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
# output_grad: [m/q^2, n, h/q]
with torch.no_grad():
input_, bias = ctx.saved_tensors
# [m/q^2, n, h/q]
input_grad = torch.mul(output_grad, bias)
# [h/q]
grad = torch.mul(output_grad, input_)
grad = torch.sum(grad,
dim=tuple(range(len(output_grad.shape))[:-1]))
bias_grad = reduce_scatter(grad, -1, ctx.B_group_parallel_mode)
dist.reduce(bias_grad,
dst=ctx.src_rank,
group=gpc.get_group(ctx.A_group_parallel_mode))
if gpc.get_local_rank(
ctx.A_group_parallel_mode) != gpc.get_local_rank(
ctx.C_group_parallel_mode):
bias_grad = None
return input_grad, bias_grad, None, None, None, None
class Sum_3D(torch.autograd.Function):
"""Compute the sum of input tensors
"""
@staticmethod
def forward(ctx: Any,
input_: Tensor,
dim: int,
depth: int,
parallel_mode: ParallelMode,
keepdim: bool = False) -> Tensor:
# input: [m/q^2, n, h/q]
out = torch.sum(input_, dim=dim, keepdim=keepdim)
dist.all_reduce(out, group=gpc.get_group(parallel_mode))
ctx.input_shape = input_.shape
ctx.depth = depth
ctx.group = parallel_mode
ctx.dim = dim
return out
@staticmethod
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
with torch.no_grad():
output_grad = output_grad.contiguous()
dist.all_reduce(output_grad, group=gpc.get_group(ctx.group))
if len(output_grad.shape) < len(ctx.input_shape):
output_grad = torch.unsqueeze(output_grad, ctx.dim)
dims = [1 for _ in range(len(output_grad.shape))]
dims[ctx.dim] = ctx.input_shape[ctx.dim]
input_grad = output_grad.repeat(tuple(dims))
return input_grad, None, None, None, None, None
class Reduce_3D(torch.autograd.Function):
"""Reduce input tensors
"""
@staticmethod
def forward(ctx: Any, input_: Tensor, depth: int,
parallel_mode: ParallelMode) -> Tensor:
dist.all_reduce(input_, group=gpc.get_group(parallel_mode))
return input_.clone()
@staticmethod
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
return output_grad, None, None
class Slice_3D(torch.autograd.Function):
"""Slice input tensor
"""
@staticmethod
def forward(ctx: Any, input_: Tensor, dim: int, depth: int,
parallel_mode: ParallelMode) -> Tensor:
rank = gpc.get_local_rank(parallel_mode)
out = torch.chunk(input_, depth, dim=dim)[rank].contiguous()
ctx.depth = depth
ctx.parallel_mode = parallel_mode
ctx.dim = dim
ctx.input_shape = input_.shape
return out
@staticmethod
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
with torch.no_grad():
input_grad = all_gather(output_grad, ctx.dim, ctx.parallel_mode)
            input_grad = input_grad.reshape(ctx.input_shape)
return input_grad, None, None, None
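
The reduce_scatter at the end of these matmul kernels is what sums the partial products produced by the different shards of the contraction dimension; a single-process illustration of that identity (not the distributed kernel itself):

    import torch

    q = 2
    A = torch.randn(4, 3, 6)        # [m, n, k]
    B = torch.randn(6, 8)           # [k, h]
    A_shards = A.chunk(q, dim=-1)   # each [m, n, k/q]
    B_shards = B.chunk(q, dim=0)    # each [k/q, h]
    C = sum(a @ b for a, b in zip(A_shards, B_shards))   # sum of partial products over the k shards
    assert torch.allclose(C, A @ B, atol=1e-5)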

49
colossalai/nn/layer/parallel_3d/_utils.py

@@ -0,0 +1,49 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import os
from colossalai.constants import DEPTH_3D
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from torch import Tensor
def get_depth_from_env() -> int:
try:
depth = os.environ[DEPTH_3D]
depth = int(depth)
assert depth > 0, 'DEPTH must be greater than zero'
return depth
except KeyError as e:
raise EnvironmentError(
'DEPTH is not found in the current environment, '
'please make sure that you have used the correct process group initializer'
)
def get_last_group(a, b):
mapping = {
ParallelMode.PARALLEL_3D_INPUT: 'A',
ParallelMode.PARALLEL_3D_WEIGHT: 'B',
ParallelMode.PARALLEL_3D_OUTPUT: 'C',
}
res = chr(
ord('A') + ord('B') + ord('C') - ord(mapping[a]) - ord(mapping[b]))
if res == 'A':
return ParallelMode.PARALLEL_3D_INPUT
elif res == 'B':
return ParallelMode.PARALLEL_3D_WEIGHT
elif res == 'C':
return ParallelMode.PARALLEL_3D_OUTPUT
def dbg_check_shape(tensor: Tensor, shape: tuple):
rank = gpc.get_global_rank()
if rank == 0:
print(tensor.shape)
assert tensor.shape == shape, \
'{} does not match {}'.format(tensor.shape, shape)
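
get_last_group returns whichever of the three 3D modes was not passed in, which is how a layer derives its output group from its input and weight groups; for example (module path taken from this diff):

    from colossalai.context.parallel_mode import ParallelMode
    from colossalai.nn.layer.parallel_3d._utils import get_last_group

    assert get_last_group(ParallelMode.PARALLEL_3D_INPUT,
                          ParallelMode.PARALLEL_3D_WEIGHT) == ParallelMode.PARALLEL_3D_OUTPUT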

368
colossalai/nn/layer/parallel_3d/_vit.py

@@ -0,0 +1,368 @@
import math
from typing import Tuple
import torch
import torch.distributed as dist
from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc
from colossalai.registry import LAYERS
from colossalai.utils import checkpoint, get_current_device
from torch import Tensor, dtype, nn
from .._common_utils import ACT2FN, divide, set_tensor_parallel_attribute
from ..vanilla_vision_transformer.layers import to_2tuple
from ._utils import get_depth_from_env
from .layers import Linear3D
@LAYERS.register_module
class ViTPatchEmbedding3D(nn.Module):
""" 3D Image to Patch Embedding
    :param img_size: image size
:type img_size: int
:param patch_size: patch size
:type patch_size: int
:param in_chans: number of channels of input image
:type in_chans: int
:param embed_size: dimension of embedding
:type embed_size: int
:param drop_prob: dropout probability
:type drop_prob: float
:param flatten: whether to flatten output tensor, defaults to True
:type flatten: bool, optional
"""
def __init__(self,
img_size: int,
patch_size: int,
in_chans: int,
embed_size: int,
drop_prob: float,
flatten: bool = True):
super().__init__()
self.depth = get_depth_from_env()
self.input_parallel_mode = ParallelMode.PARALLEL_3D_INPUT
self.weight_parallel_mode = ParallelMode.PARALLEL_3D_WEIGHT
self.output_parallel_mode = ParallelMode.PARALLEL_3D_OUTPUT
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0],
img_size[1] // patch_size[1])
self.embed_size = embed_size
self.embed_size_per_partition = divide(self.embed_size, self.depth)
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
with seed(ParallelMode.TENSOR):
self.proj = nn.Conv2d(in_chans,
self.embed_size_per_partition,
kernel_size=patch_size,
stride=patch_size)
self.cls_token = nn.Parameter(
torch.zeros(1, 1, self.embed_size_per_partition))
self.pos_embed = nn.Parameter(
torch.zeros(1, self.num_patches + 1,
self.embed_size_per_partition))
self.pos_drop = nn.Dropout(drop_prob)
self._sync_parameters()
self.proj.weight.register_hook(self._sync_grad_hook)
self.proj.bias.register_hook(self._sync_grad_hook)
self.cls_token.register_hook(self._sync_grad_hook)
self.pos_embed.register_hook(self._sync_grad_hook)
self._set_tensor_parallel_attribute()
def _set_tensor_parallel_attribute(self):
set_tensor_parallel_attribute(self.proj.weight)
set_tensor_parallel_attribute(self.proj.bias)
set_tensor_parallel_attribute(self.cls_token)
set_tensor_parallel_attribute(self.pos_embed)
def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
return self.input_parallel_mode, self.weight_parallel_mode
def _sync_parameters(self):
self.to(get_current_device())
weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
dist.broadcast(self.proj.weight,
src=weight_src_rank,
group=gpc.get_group(self.weight_parallel_mode))
dist.broadcast(self.proj.bias,
src=weight_src_rank,
group=gpc.get_group(self.weight_parallel_mode))
input_src_rank = gpc.get_ranks_in_group(self.input_parallel_mode)[0]
dist.broadcast(self.proj.weight,
src=input_src_rank,
group=gpc.get_group(self.input_parallel_mode))
dist.broadcast(self.proj.bias,
src=input_src_rank,
group=gpc.get_group(self.input_parallel_mode))
set_tensor_parallel_attribute(self.proj.weight)
set_tensor_parallel_attribute(self.proj.bias)
set_tensor_parallel_attribute(self.cls_token)
set_tensor_parallel_attribute(self.pos_embed)
    def _sync_grad_hook(self, grad) -> Tensor:
dist.all_reduce(grad, group=gpc.get_group(self.input_parallel_mode))
dist.all_reduce(grad, group=gpc.get_group(self.weight_parallel_mode))
return grad
def forward(self, x: Tensor) -> Tensor:
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
# split a partition from embedded states
x = torch.chunk(x, self.depth, dim=0)[gpc.get_local_rank(
self.weight_parallel_mode)].contiguous()
x = torch.chunk(x, self.depth, dim=0)[gpc.get_local_rank(
self.input_parallel_mode)].contiguous()
# add cls token & pos embedding
# [b/q^2,s,h/q] --> [b/q^2, 1+s, h/q]
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_token, x), dim=1)
with seed(ParallelMode.TENSOR):
x = self.pos_drop(x + self.pos_embed)
return x
@LAYERS.register_module
class ViTSelfAttention3D(nn.Module):
"""Self-attention layer for 3D parallel Vision Transformer
:param hidden_size: hidden size
:type hidden_size: int
:param num_attention_heads: number of attention heads
:type num_attention_heads: int
:param attention_probs_dropout_prob: dropout probability for attention layers
    :type attention_probs_dropout_prob: float
    :param hidden_dropout_prob: dropout probability for hidden layers
    :type hidden_dropout_prob: float
    :param dtype: dtype of parameters, defaults to None
    :type dtype: dtype, optional
    :param bias: whether to add bias, defaults to True
    :type bias: bool, optional
    :param checkpoint: If set to `True`, activation checkpoint is used, defaults to `False`
    :type checkpoint: bool, optional
"""
def __init__(self,
hidden_size: int,
num_attention_heads: int,
attention_probs_dropout_prob: float,
hidden_dropout_prob: float,
dtype: dtype = None,
bias: bool = True,
checkpoint: bool = False):
super().__init__()
self.depth = get_depth_from_env()
self.input_parallel_mode = ParallelMode.PARALLEL_3D_INPUT
self.weight_parallel_mode = ParallelMode.PARALLEL_3D_WEIGHT
self.output_parallel_mode = ParallelMode.PARALLEL_3D_OUTPUT
self.hidden_size = hidden_size
self.num_attention_heads = divide(num_attention_heads, self.depth)
self.attention_head_size = divide(hidden_size, num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.checkpoint = checkpoint
self.query_key_value = Linear3D(self.hidden_size,
3 * self.hidden_size,
self.input_parallel_mode,
self.weight_parallel_mode,
dtype=dtype,
bias=bias)
self.attention_dropout = nn.Dropout(attention_probs_dropout_prob)
self.dense = Linear3D(self.hidden_size,
self.hidden_size,
self.output_parallel_mode,
self.weight_parallel_mode,
dtype=dtype,
bias=bias)
self.dropout = nn.Dropout(hidden_dropout_prob)
self.softmax = nn.Softmax(dim=-1)
def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
return self.input_parallel_mode, self.weight_parallel_mode
def _forward(self, hidden_states: Tensor) -> Tensor:
query_key_value = self.query_key_value(hidden_states)
new_qkv_shape = query_key_value.shape[:-1] + \
(self.num_attention_heads, 3 * self.attention_head_size)
query_key_value = query_key_value.view(new_qkv_shape)
query_key_value = query_key_value.permute((0, 2, 1, 3))
query_layer, key_layer, value_layer = torch.chunk(query_key_value,
3,
dim=-1)
attention_scores = torch.matmul(query_layer,
key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(
self.attention_head_size)
attention_probs = self.softmax(attention_scores)
with seed(ParallelMode.TENSOR):
attention_probs = self.attention_dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.transpose(1, 2)
new_context_layer_shape = context_layer.size()[:-2] + (
self.all_head_size, )
context_layer = context_layer.reshape(new_context_layer_shape)
output = self.dense(context_layer)
with seed(ParallelMode.TENSOR):
output = self.dropout(output)
return output
def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
return checkpoint(self._forward, hidden_states)
def forward(self, hidden_states: Tensor) -> Tensor:
if self.checkpoint:
return self._checkpoint_forward(hidden_states)
else:
return self._forward(hidden_states)
@LAYERS.register_module
class ViTMLP3D(nn.Module):
"""[summary]
:param hidden_size: hidden size
:type hidden_size: int
:param mlp_ratio: hidden size of MLP divided by embedding dim
:type mlp_ratio: int
:param hidden_dropout_prob: dropout probability for hidden layers
:type hidden_dropout_prob: float
:param hidden_act: activation function for hidden layers
:type hidden_act: str
    :param dtype: dtype of parameters, defaults to None
    :type dtype: dtype, optional
    :param bias: whether to add bias, defaults to True
    :type bias: bool, optional
    :param checkpoint: If set to `True`, activation checkpoint is used, defaults to `False`
    :type checkpoint: bool, optional
"""
def __init__(self,
hidden_size: int,
mlp_ratio: int,
hidden_dropout_prob: float,
hidden_act: str = 'gelu',
dtype: dtype = None,
bias: bool = True,
checkpoint: bool = False):
super().__init__()
self.depth = get_depth_from_env()
self.input_parallel_mode = ParallelMode.PARALLEL_3D_INPUT
self.weight_parallel_mode = ParallelMode.PARALLEL_3D_WEIGHT
self.output_parallel_mode = ParallelMode.PARALLEL_3D_OUTPUT
self.hidden_size = hidden_size
self.mlp_ratio = mlp_ratio
self.checkpoint = checkpoint
self.dense_1 = Linear3D(self.hidden_size,
self.mlp_ratio * self.hidden_size,
self.input_parallel_mode,
self.weight_parallel_mode,
dtype=dtype,
bias=bias)
self.activation_func = ACT2FN[hidden_act]
self.dense_2 = Linear3D(self.mlp_ratio * self.hidden_size,
self.hidden_size,
self.output_parallel_mode,
self.weight_parallel_mode,
dtype=dtype,
bias=bias)
self.dropout = nn.Dropout(hidden_dropout_prob)
def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
return self.input_parallel_mode, self.weight_parallel_mode
def _forward(self, hidden_states: Tensor) -> Tensor:
intermediate_output = self.dense_1(hidden_states)
intermediate_output = self.activation_func(intermediate_output)
output = self.dense_2(intermediate_output)
with seed(ParallelMode.TENSOR):
output = self.dropout(output)
return output
def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
return checkpoint(self._forward, hidden_states)
def forward(self, hidden_states: Tensor) -> Tensor:
if self.checkpoint:
return self._checkpoint_forward(hidden_states)
else:
return self._forward(hidden_states)
@LAYERS.register_module
class ViTHead3D(nn.Module):
"""Output layer for 3D parallel Vision Transformer
:param in_features: size of input tensor
:type in_features: int
:param num_classes: number of classes
:type num_classes: int
    :param dtype: dtype of parameters, defaults to None
    :type dtype: dtype, optional
    :param bias: whether to add bias, defaults to True
    :type bias: bool, optional
"""
def __init__(self,
in_features: int,
num_classes: int,
dtype: dtype = None,
bias: bool = True):
super().__init__()
self.depth = get_depth_from_env()
self.input_parallel_mode = ParallelMode.PARALLEL_3D_INPUT
self.weight_parallel_mode = ParallelMode.PARALLEL_3D_WEIGHT
self.output_parallel_mode = ParallelMode.PARALLEL_3D_OUTPUT
self.in_features = in_features
self.num_classes = num_classes
out_features = math.ceil(self.num_classes /
(self.depth**2)) * (self.depth**2)
self.num_classes_per_partition = divide(self.num_classes, self.depth)
self.linear = Linear3D(self.in_features,
out_features,
self.input_parallel_mode,
self.weight_parallel_mode,
dtype=dtype,
bias=bias)
def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
return self.linear.groups_for_next_layer()
def forward(self, x: Tensor) -> Tensor:
# [b/q^2, s, h/q] --> [b/q^2, h/q]
x = x[:, 0]
# [b/q^2, h/q] --> [b/q^2, c/q]
x = self.linear(x)
return x[:, :self.num_classes_per_partition]
def extra_repr(self):
return 'in_features={}, num_classes={}'.format(self.in_features,
self.num_classes)
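
ViTHead3D pads the class dimension up to a multiple of depth**2 so the classifier weight shards evenly, then keeps only the first num_classes / depth logits of each local output; the arithmetic for a common setting:

    import math

    depth, num_classes = 4, 1000
    out_features = math.ceil(num_classes / depth**2) * depth**2   # 1008, evenly divisible by depth**2
    kept_per_rank = num_classes // depth                          # 250 logits sliced back per rank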

172
colossalai/nn/layer/parallel_3d/layers.py

@@ -0,0 +1,172 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math
from typing import Tuple
import torch
import torch.nn as nn
from colossalai.context import ParallelMode, seed
from colossalai.registry import LAYERS
from colossalai.utils import get_current_device
from torch import Tensor, dtype
from torch.nn import Parameter
from .._common_utils import divide, set_tensor_parallel_attribute
from ._operation import Add_3D, Matmul_AB_3D, Mul_3D, Sum_3D
from ._utils import get_depth_from_env, get_last_group
@LAYERS.register_module
class LayerNorm3D(nn.Module):
def __init__(
self,
normalized_shape: int,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
eps: float = 1e-12,
dtype: dtype = None,
):
super().__init__()
self.input_parallel_mode = input_parallel_mode
self.weight_parallel_mode = weight_parallel_mode
self.output_parallel_mode = get_last_group(self.input_parallel_mode,
self.weight_parallel_mode)
self.depth = get_depth_from_env()
self.normalized_shape = normalized_shape
self.normalized_shape_per_partition = divide(normalized_shape,
self.depth**2)
self.weight = Parameter(
torch.ones(self.normalized_shape_per_partition,
device=get_current_device(),
dtype=dtype))
self.bias = Parameter(
torch.zeros(self.normalized_shape_per_partition,
device=get_current_device(),
dtype=dtype))
self.variance_epsilon = eps
self._set_tensor_parallel_attributes()
def _set_tensor_parallel_attributes(self):
set_tensor_parallel_attribute(self.weight)
set_tensor_parallel_attribute(self.bias)
def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
return self.input_parallel_mode, self.weight_parallel_mode
def reset_parameters(self):
nn.init.zeros_(self.bias)
nn.init.ones_(self.weight)
def forward(self, input_: Tensor) -> Tensor:
'''x = weight * (x - mean) / sqrt(var + eps) + bias'''
# input: [m/q^2, n, h/q]
# [m/q^2, n, 1]
mean = Sum_3D.apply(input_, -1, self.depth, self.output_parallel_mode,
True) / self.normalized_shape
# [m/q^2, n, 1]
var = (input_ - mean).pow(2)
var = Sum_3D.apply(var, -1, self.depth, self.output_parallel_mode,
True) / self.normalized_shape
output = (input_ - mean) / torch.sqrt(var + self.variance_epsilon)
output = Mul_3D.apply(output, self.weight, self.depth,
self.input_parallel_mode,
self.weight_parallel_mode,
self.output_parallel_mode)
output = Add_3D.apply(output, self.bias, self.depth,
self.input_parallel_mode,
self.weight_parallel_mode,
self.output_parallel_mode)
return output
def extra_repr(self):
return '{}, eps={}'.format(self.normalized_shape,
self.variance_epsilon)
@LAYERS.register_module
class Linear3D(nn.Module):
def __init__(self,
in_features: int,
out_features: int,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
bias: bool = True,
dtype: dtype = None):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.input_parallel_mode = input_parallel_mode
self.weight_parallel_mode = weight_parallel_mode
self.output_parallel_mode = get_last_group(self.input_parallel_mode,
self.weight_parallel_mode)
self.with_bias = bias
self.depth = get_depth_from_env()
self.in_features_per_partition = divide(in_features, self.depth)
self.out_features_per_partition = divide(out_features, self.depth**2)
# [k/q, h/q^2]
self.weight = Parameter(
torch.empty(self.in_features_per_partition,
self.out_features_per_partition,
device=get_current_device(),
dtype=dtype))
# [h/q^2]
if bias:
self.bias = Parameter(
torch.zeros(self.out_features_per_partition,
device=get_current_device(),
dtype=dtype))
else:
self.register_parameter('bias', None)
self.reset_parameters()
self._set_tensor_parallel_attributes()
def _set_tensor_parallel_attributes(self):
set_tensor_parallel_attribute(self.weight)
if self.bias is not None:
set_tensor_parallel_attribute(self.bias)
def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
return self.output_parallel_mode, self.weight_parallel_mode
def reset_parameters(self):
# setting
fan_in = self.in_features
a = math.sqrt(5)
nonlinearity = 'leaky_relu'
# init weight
std = nn.init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
bound = math.sqrt(3.0) * std
with seed(ParallelMode.TENSOR):
nn.init.uniform_(self.weight, -bound, bound)
# init bias
if self.with_bias:
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
with seed(ParallelMode.TENSOR):
nn.init.uniform_(self.bias, -bound, bound)
def forward(self, input_: Tensor) -> Tensor:
# input: [m/q^2, n, k/q]
# output: [m/q^2, n, h/q]
output = Matmul_AB_3D.apply(input_, self.weight, self.depth,
self.input_parallel_mode,
self.weight_parallel_mode,
self.output_parallel_mode)
if self.with_bias:
output = Add_3D.apply(output, self.bias, self.depth,
self.output_parallel_mode,
self.weight_parallel_mode,
self.input_parallel_mode)
return output
def extra_repr(self):
return 'in_features={}, out_features={}, bias={}'.format(
self.in_features, self.out_features, self.with_bias)
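
A minimal sketch of Linear3D under an initialized 3D context with depth q = 2 (so DEPTH_3D and the PARALLEL_3D_* groups exist on every rank); the shard shapes follow the comments in forward:

    import torch
    from colossalai.context import ParallelMode
    from colossalai.utils import get_current_device
    from colossalai.nn.layer.parallel_3d import Linear3D

    fc = Linear3D(1024, 4096,
                  ParallelMode.PARALLEL_3D_INPUT,
                  ParallelMode.PARALLEL_3D_WEIGHT)
    x = torch.randn(8, 128, 1024 // 2, device=get_current_device())   # local [m/q^2, n, k/q]
    y = fc(x)                                                          # local [m/q^2, n, h/q]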

4
colossalai/nn/layer/parallel_sequence/__init__.py

@@ -0,0 +1,4 @@
from ._operation import RingQK, RingAV
from .layers import TransformerSelfAttentionRing
__all__ = ['TransformerSelfAttentionRing', 'RingAV', 'RingQK']

169
colossalai/nn/layer/parallel_sequence/_operation.py

@@ -0,0 +1,169 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
from torch import distributed as dist
from colossalai.communication import ring_forward
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer.parallel_sequence._utils import _calc_incoming_device_range, _calc_current_device_range
from colossalai.utils import get_current_device
class RingQK(torch.autograd.Function):
"""
Calculate QK in a ring-exchange style
"""
@staticmethod
def forward(ctx,
sub_q,
sub_k,
batch_size,
num_attention_heads,
sub_seq_length):
# save tensor for backward
ctx.save_for_backward(sub_q, sub_k)
ctx.sub_seq_length = sub_seq_length
# create local segment of attention score
attention_score = torch.empty(
batch_size * num_attention_heads,
sub_seq_length,
sub_seq_length * gpc.get_world_size(ParallelMode.SEQUENCE),
dtype=sub_q.dtype,
device=get_current_device()
)
# compute local QK^T
part_a = torch.matmul(sub_q, sub_k.transpose(2, 1))
local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
start_idx = local_rank * sub_seq_length
end_idx = (local_rank + 1) * sub_seq_length
attention_score[:, :, start_idx: end_idx] = part_a
# compute QK^T in ring-all-reduce style
for i in range(local_world_size - 1):
sub_k = ring_forward(sub_k, ParallelMode.SEQUENCE)
start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, sub_seq_length)
part_a = torch.matmul(sub_q, sub_k.transpose(2, 1))
attention_score[:, :, start_idx:end_idx] = part_a
return attention_score
@staticmethod
def backward(ctx, grad_output):
        sub_q, sub_k = ctx.saved_tensors
local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
# calculate gradient of sub_k
grad_k = torch.matmul(
grad_output.transpose(2, 1),
sub_q
)
dist.all_reduce(grad_k, group=gpc.get_group(ParallelMode.SEQUENCE))
grad_k = grad_k[:, local_rank * ctx.sub_seq_length: (local_rank + 1) * ctx.sub_seq_length]
grad_k /= local_world_size
# calculate gradient for sub_q
grad_q = torch.zeros_like(sub_q,
dtype=sub_q.dtype,
device=get_current_device(), )
# compute with local sub_k
start_idx, end_idx = _calc_current_device_range(local_rank, ctx.sub_seq_length)
grad_q += torch.matmul(grad_output[:, :, start_idx:end_idx], sub_k)
        # accumulate the gradient of sub_q in ring-exchange style
for i in range(local_world_size - 1):
sub_k = ring_forward(sub_k, ParallelMode.SEQUENCE)
start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, ctx.sub_seq_length)
grad_q += torch.matmul(grad_output[:, :, start_idx: end_idx], sub_k)
grad_q /= local_world_size
return grad_q, grad_k, None, None, None
class RingAV(torch.autograd.Function):
"""
Calculate AV in a ring-exchange style
"""
@staticmethod
def forward(ctx,
attention_score,
sub_v,
batch_size,
num_attention_heads,
attention_head_size,
sub_seq_length):
local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
local_start_idx, local_end_idx = _calc_current_device_range(local_rank, sub_seq_length)
sub_attention_result = torch.zeros(
batch_size * num_attention_heads,
sub_seq_length,
attention_head_size,
device=get_current_device(),
dtype=attention_score.dtype)
# save tensors for backward
ctx.save_for_backward(attention_score, sub_v)
ctx.sub_seq_length = sub_seq_length
# compute local AV
part_av = torch.matmul(attention_score[:, :, local_start_idx:local_end_idx], sub_v)
sub_attention_result += part_av
        # compute AV in ring-all-reduce style
for i in range(local_world_size - 1):
sub_v = ring_forward(sub_v, ParallelMode.SEQUENCE)
start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, sub_seq_length)
            # compute the partial AV with the incoming sub_v
part_av = torch.matmul(attention_score[:, :, start_idx:end_idx], sub_v)
sub_attention_result += part_av
return sub_attention_result
@staticmethod
def backward(ctx, grad_output):
local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
local_start_idx, local_end_idx = _calc_current_device_range(local_rank, ctx.sub_seq_length)
attention_scores, sub_v = ctx.saved_tensors
# calculate gradient of v
grad_v = torch.matmul(
attention_scores.transpose(2, 1),
grad_output
)
dist.all_reduce(grad_v, group=gpc.get_group(ParallelMode.SEQUENCE))
grad_v = grad_v[:, local_start_idx:local_end_idx]
grad_v /= local_world_size
# calculate gradient for attention score
grad_attention_score = torch.zeros_like(attention_scores,
dtype=grad_output.dtype,
device=get_current_device())
        # compute with local sub_v
grad_attention_score[:, :, local_start_idx:local_end_idx] += torch.matmul(
grad_output,
sub_v.transpose(2, 1))
        # accumulate the attention-score gradient in ring-all-reduce style
for i in range(local_world_size - 1):
sub_v = ring_forward(sub_v, ParallelMode.SEQUENCE)
start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, ctx.sub_seq_length)
            # compute the partial attention-score gradient with the incoming sub_v
grad_attention_score[:, :, start_idx:end_idx] += torch.matmul(
grad_output,
sub_v.transpose(2, 1))
return grad_attention_score, grad_v, None, None, None, None

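For readers following the ring exchange above, the block below is a single-process reference for the quantities RingQK and RingAV assemble collectively: the full score matrix QK^T and the context AV. It assumes a sequence-parallel world size of 1 (so no communication) and hypothetical shapes; it is a sanity-check sketch, not the distributed path.

import torch

# shapes follow the comments in RingQK/RingAV: [batch * num_heads, seq, head_dim]
batch_heads, seq, head_dim = 2 * 4, 16, 32
q = torch.randn(batch_heads, seq, head_dim)
k = torch.randn(batch_heads, seq, head_dim)
v = torch.randn(batch_heads, seq, head_dim)

scores = torch.matmul(q, k.transpose(2, 1))          # what RingQK builds ring step by ring step
context = torch.matmul(scores.softmax(dim=-1), v)    # what RingAV accumulates
print(scores.shape, context.shape)                   # [8, 16, 16] and [8, 16, 32]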
15
colossalai/nn/layer/parallel_sequence/_utils.py

@ -0,0 +1,15 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
def _calc_incoming_device_range(i, rank, world_size, sub_seq_length):
device_of_incoming_k = (rank - i - 1) % world_size
start_idx = sub_seq_length * device_of_incoming_k
end_idx = sub_seq_length * (device_of_incoming_k + 1)
return start_idx, end_idx
def _calc_current_device_range(rank, sub_seq_length):
start_idx = sub_seq_length * rank
end_idx = sub_seq_length * (rank + 1)
return start_idx, end_idx

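The two helpers above decide which slice of the full sequence each ring step corresponds to. The standalone sketch below (hypothetical sizes) reproduces the formula of _calc_incoming_device_range to show the walk for rank 1 in a 4-way sequence group.

def calc_incoming_device_range(i, rank, world_size, sub_seq_length):
    # same formula as _calc_incoming_device_range above
    device_of_incoming_k = (rank - i - 1) % world_size
    return sub_seq_length * device_of_incoming_k, sub_seq_length * (device_of_incoming_k + 1)

for step in range(3):
    print(step, calc_incoming_device_range(step, rank=1, world_size=4, sub_seq_length=128))
# 0 (0, 128)    -> keys arriving from rank 0
# 1 (384, 512)  -> keys arriving from rank 3
# 2 (256, 384)  -> keys arriving from rank 2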
188
colossalai/nn/layer/parallel_sequence/layers.py

@ -0,0 +1,188 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer.parallel_sequence._operation import RingQK, RingAV
from colossalai.registry import LAYERS
@LAYERS.register_module
class TransformerSelfAttentionRing(nn.Module):
"""Parallel self-attention layer abstract class.
Self-attention layer takes input with size [b, s, h]
and returns output of the same size.
:param hidden_size: hidden size
:type hidden_size: int
:param kv_channels: channels of key/value tensor
:type kv_channels: int
:param num_attention_heads: number of attention heads
:type num_attention_heads: int
:param attention_dropout: dropout probability for attention layer
:type attention_dropout: float
"""
def __init__(self,
hidden_size,
kv_channels,
num_attention_heads,
attention_dropout,
):
super().__init__()
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
projection_size = kv_channels * num_attention_heads
self.hidden_size_per_attention_head = projection_size // num_attention_heads
self.world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
# Strided linear layer.
self.query_key_value = nn.Linear(
hidden_size,
3 * projection_size,
)
# coeff = None
self.norm_factor = math.sqrt(self.hidden_size)
# TODO: add apply_query_key_layer_scaling when we have the kernel module
# if self.apply_query_key_layer_scaling:
# coeff = self.layer_number
# self.norm_factor *= coeff
# TODO: add fused scale mask softmax kernel when we have the kernel module
# self.scale_mask_softmax = FusedScaleMaskSoftmax(
# self.fp16, self.bf16,
# self.attn_mask_type,
# masked_softmax_fusion,
# attention_mask_func,
# self.attention_softmax_in_fp32,
# coeff)
self.attention_dropout = nn.Dropout(attention_dropout)
# Output.
self.dense = nn.Linear(
projection_size,
hidden_size,
bias=True)
def forward(self, hidden_states, attention_mask):
# hidden_states: [sq, b, h]
sub_seq_length, batch_size, hidden_size = hidden_states.size()
# =====================
# Query, Key, and Value
# =====================
# Attention heads [sq, b, h] --> [sq, b, (3 * hn * num_heads)]
mixed_x_layer = self.query_key_value(hidden_states)
        # [sq, b, (3 * hn * num_heads)] --> [sq, b, num_heads, 3 * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_attention_heads,
3 * self.hidden_size_per_attention_head)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# split into query, key and value
last_dim = mixed_x_layer.dim() - 1
last_dim_value = mixed_x_layer.size()[-1]
assert last_dim_value % 3 == 0, 'the last dimension is not a multiple of 3, ' \
'cannot be divided into query, key and value'
partition_size = last_dim_value // 3
(query_layer, key_layer, value_layer) = torch.split(
mixed_x_layer, partition_size, dim=last_dim)
# ===================================
# Raw attention scores. [b, num_heads, s, s]
# ===================================
# [b, num_heads, sq, sk]
output_size = (query_layer.size(1),
query_layer.size(2),
query_layer.size(0),
key_layer.size(0) * self.world_size)
# [sq, b, num_heads, hn] -> [sq, b * num_heads, hn]
query_layer = query_layer.view(output_size[2],
output_size[0] * output_size[1], -1)
# [sk, b, num_heads, hn] -> [sk, b * num_heads, hn]
key_layer = key_layer.view(key_layer.size(0),
output_size[0] * output_size[1], -1)
# [b, sq, sk]
attention_scores = RingQK.apply(
# [b * num_heads, sq, hn]
query_layer.transpose(0, 1).contiguous(),
key_layer.transpose(0, 1).contiguous(), # [b * num_heads, sk, hn],
batch_size,
self.num_attention_heads,
sub_seq_length
)
attention_scores /= self.norm_factor
# change view to [b, num_heads, sq, sk]
attention_scores = attention_scores.view(*output_size)
attention_scores = attention_scores.unsqueeze(1)
attention_scores = attention_scores + attention_mask
attention_probs = F.softmax(attention_scores, dim=-1)
attention_probs = attention_probs.squeeze(1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
# with mpu.get_cuda_rng_tracker().fork():
# TODO: check if a rng tracker is needed
attention_probs = self.attention_dropout(attention_probs)
# context layer shape: [b, num_heads, sq, hn]
output_size = (value_layer.size(1),
value_layer.size(2),
query_layer.size(0),
value_layer.size(3))
        # change view [sk, b * num_heads, hn]
value_layer = value_layer.contiguous().view(value_layer.size(0),
output_size[0] * output_size[1], -1)
        # change view [b * num_heads, sq, sk]
attention_probs = attention_probs.view(attention_probs.size(0) * attention_probs.size(1),
attention_probs.size(2),
attention_probs.size(3))
# matmul: [b*num_heads, sq, hn]
# context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
context_layer = RingAV.apply(
attention_probs,
value_layer.transpose(0, 1).contiguous(),
batch_size,
self.num_attention_heads,
self.hidden_size_per_attention_head,
sub_seq_length
)
        # change view [b, num_heads, sq, hn]
context_layer = context_layer.view(*output_size)
        # [b, np, sq, hn] --> [sq, b, np, hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
        # [sq, b, np, hn] --> [sq, b, hp]
new_context_layer_shape = context_layer.size()[:-2] + (
self.hidden_size_per_attention_head * self.num_attention_heads,)
context_layer = context_layer.view(*new_context_layer_shape)
# context_layer = context_layer.transpose(1, 0).contiguous()
output = self.dense(context_layer)
bias = self.dense.bias
return output, bias

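The reshapes at the top of forward() are the trickiest part of this layer. The snippet below is a single-process sketch of the same head split applied to the fused QKV projection; the sizes are hypothetical and no process groups are required.

import torch

sub_seq, batch, heads, head_dim = 16, 2, 4, 32
hidden = heads * head_dim
mixed = torch.randn(sub_seq, batch, 3 * hidden)            # stands in for query_key_value output
mixed = mixed.view(sub_seq, batch, heads, 3 * head_dim)    # [sq, b, num_heads, 3 * hn]
q, k, v = torch.split(mixed, head_dim, dim=-1)             # three tensors of [sq, b, num_heads, hn]
print(q.shape, k.shape, v.shape)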
3
colossalai/nn/layer/parallel_vision_transformer/__init__.py

@ -0,0 +1,3 @@
from .layers import ViTBlock
__all__ = ['ViTBlock']

59
colossalai/nn/layer/parallel_vision_transformer/layers.py

@ -0,0 +1,59 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from torch import nn as nn
from colossalai.builder import build_layer
from colossalai.registry import LAYERS
@LAYERS.register_module
class ViTBlock(nn.Module):
"""Vision Transformer block
:param attention_cfg: config of attention layer
:type attention_cfg: dict
:param droppath_cfg: config of drop path
:type droppath_cfg: dict
:param mlp_cfg: config of MLP layer
:type mlp_cfg: dict
    :param norm_cfg: config of normalization layer
:type norm_cfg: dict
"""
def __init__(self,
attention_cfg: dict,
droppath_cfg: dict,
mlp_cfg: dict,
norm_cfg: dict,
):
super().__init__()
self.norm1 = build_layer(norm_cfg)
self.attn = build_layer(attention_cfg)
self.drop_path = build_layer(
droppath_cfg) if droppath_cfg['drop_path'] > 0. else nn.Identity()
self.norm2 = build_layer(norm_cfg)
self.mlp = build_layer(mlp_cfg)
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
# x_ = x
# x_ = self.norm1(x_)
# if self.checkpoint:
# x_ = checkpoint(self.attn, x_)
# else:
# x_ = self.attn(x_)
# x_ = self.drop_path(x_)
# x = x + x_
#
# x_ = x
# x_ = self.norm2(x_)
# if self.checkpoint:
# x_ = checkpoint(self.mlp, x_)
# else:
# x_ = self.mlp(x_)
# x_ = self.drop_path(x_)
# x = x + x_
return x

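ViTBlock wires its sublayers through the registry, so the concrete modules depend on the configs passed in. As a plain-PyTorch sketch of the same pre-norm residual structure applied in forward() (stand-in modules, drop path omitted, requires a recent PyTorch for batch_first):

import torch
from torch import nn

dim = 64
norm1, norm2 = nn.LayerNorm(dim), nn.LayerNorm(dim)
attn = nn.MultiheadAttention(dim, num_heads=4, batch_first=True)   # stand-in for attention_cfg
mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

x = torch.randn(2, 16, dim)          # [batch, tokens, dim]
y = norm1(x)
x = x + attn(y, y, y)[0]             # x + drop_path(attn(norm1(x)))
x = x + mlp(norm2(x))                # x + drop_path(mlp(norm2(x)))
print(x.shape)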
5
colossalai/nn/layer/vanilla_resnet/__init__.py

@ -0,0 +1,5 @@
from .basic_block import ResNetBasicBlock
from .bottleneck import ResNetBottleneck
from .reslayer import ResLayer
__all__ = ['ResLayer', 'ResNetBottleneck', 'ResNetBasicBlock']

64
colossalai/nn/layer/vanilla_resnet/basic_block.py

@ -0,0 +1,64 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Optional, Callable
import torch.nn as nn
from torch import Tensor
from colossalai.registry import LAYERS
from .conv import conv3x3
@LAYERS.register_module
class ResNetBasicBlock(nn.Module):
"""Basic ResNet block
"""
expansion: int = 1
def __init__(
self,
inplanes: int,
planes: int,
stride: int = 1,
downsample: Optional[nn.Module] = None,
groups: int = 1,
base_width: int = 64,
dilation: int = 1,
norm_layer: Optional[Callable[..., nn.Module]] = None
) -> None:
super().__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
if groups != 1 or base_width != 64:
raise ValueError(
'BasicBlock only supports groups=1 and base_width=64')
if dilation > 1:
raise NotImplementedError(
"Dilation > 1 not supported in BasicBlock")
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = norm_layer(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x: Tensor) -> Tensor:
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out

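As a quick sanity check (hypothetical input size; assumes colossalai is importable so the block can be constructed), the block preserves spatial size when the stride is 1:

import torch
from colossalai.nn.layer.vanilla_resnet import ResNetBasicBlock

block = ResNetBasicBlock(inplanes=64, planes=64)
print(block(torch.randn(1, 64, 56, 56)).shape)   # torch.Size([1, 64, 56, 56])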
69
colossalai/nn/layer/vanilla_resnet/bottleneck.py

@ -0,0 +1,69 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Optional, Callable
import torch.nn as nn
from torch import Tensor
from colossalai.registry import LAYERS
from .conv import conv3x3, conv1x1
@LAYERS.register_module
class ResNetBottleneck(nn.Module):
# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
# while original implementation places the stride at the first 1x1 convolution(self.conv1)
# according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
# This variant is also known as ResNet V1.5 and improves accuracy according to
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
expansion: int = 4
def __init__(
self,
inplanes: int,
planes: int,
stride: int = 1,
downsample: Optional[nn.Module] = None,
groups: int = 1,
base_width: int = 64,
dilation: int = 1,
norm_layer: Optional[Callable[..., nn.Module]] = None
) -> None:
super().__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
width = int(planes * (base_width / 64.)) * groups
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv1x1(inplanes, width)
self.bn1 = norm_layer(width)
self.conv2 = conv3x3(width, width, stride, groups, dilation)
self.bn2 = norm_layer(width)
self.conv3 = conv1x1(width, planes * self.expansion)
self.bn3 = norm_layer(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x: Tensor) -> Tensor:
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out

15
colossalai/nn/layer/vanilla_resnet/conv.py

@ -0,0 +1,15 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.nn as nn
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=dilation, groups=groups, bias=False, dilation=dilation)
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

63
colossalai/nn/layer/vanilla_resnet/reslayer.py

@ -0,0 +1,63 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.nn as nn
from colossalai.registry import LAYERS
from .conv import conv1x1
@LAYERS.register_module
class ResLayer(nn.Module):
def __init__(self,
block_type: str,
norm_layer_type: str,
inplanes: int,
planes: int,
blocks: int,
groups: int,
base_width: int,
stride: int = 1,
dilation: int = 1,
dilate: bool = False,
):
super().__init__()
self.block = LAYERS.get_module(block_type)
self.norm_layer = LAYERS.get_module(norm_layer_type)
self.inplanes = inplanes
self.planes = planes
self.blocks = blocks
self.groups = groups
self.dilation = dilation
self.base_width = base_width
self.dilate = dilate
self.stride = stride
self.layer = self._make_layer()
def _make_layer(self):
norm_layer = self.norm_layer
downsample = None
previous_dilation = self.dilation
if self.dilate:
self.dilation *= self.stride
self.stride = 1
if self.stride != 1 or self.inplanes != self.planes * self.block.expansion:
downsample = nn.Sequential(
conv1x1(self.inplanes, self.planes * self.block.expansion, self.stride),
norm_layer(self.planes * self.block.expansion),
)
layers = []
layers.append(self.block(self.inplanes, self.planes, self.stride, downsample, self.groups,
self.base_width, previous_dilation, norm_layer))
self.inplanes = self.planes * self.block.expansion
for _ in range(1, self.blocks):
layers.append(self.block(self.inplanes, self.planes, groups=self.groups,
base_width=self.base_width, dilation=self.dilation,
norm_layer=norm_layer))
return nn.Sequential(*layers)
def forward(self, x):
return self.layer(x)

7
colossalai/nn/layer/vanilla_vision_transformer/__init__.py

@ -0,0 +1,7 @@
from .layers import (VanillaViTBlock, VanillaViTMLP, VanillaViTPatchEmbedding,
VanillaViTAttention, VanillaViTDropPath, VanillaViTHead)
__all__ = [
'VanillaViTBlock', 'VanillaViTMLP', 'VanillaViTPatchEmbedding',
'VanillaViTAttention', 'VanillaViTDropPath', 'VanillaViTHead'
]

244
colossalai/nn/layer/vanilla_vision_transformer/layers.py

@ -0,0 +1,244 @@
import collections.abc
from itertools import repeat
import torch
from torch import nn as nn
from colossalai.registry import LAYERS
# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable):
return x
return tuple(repeat(x, n))
return parse
to_2tuple = _ntuple(2)
@LAYERS.register_module
class VanillaViTPatchEmbedding(nn.Module):
""" 2D Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, drop=0.):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0],
img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
self.proj = nn.Conv2d(in_chans, embed_dim,
kernel_size=patch_size, stride=patch_size)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim))
self.pos_drop = nn.Dropout(p=drop)
def forward(self, x):
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
x = self.norm(x)
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_token, x), dim=1)
x = self.pos_drop(x + self.pos_embed)
return x
@LAYERS.register_module
class VanillaViTMLP(nn.Module):
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
"""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
def drop_path(x, drop_prob: float = 0., training: bool = False):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
"""
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
# work with diff dim tensors, not just 2D ConvNets
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
random_tensor = keep_prob + \
torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output
@LAYERS.register_module
class VanillaViTDropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=0.):
super().__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
@LAYERS.register_module
class VanillaViTAttention(nn.Module):
"""Vanilla attention layer of Vision Transformer
:param dim: dimension of input tensor
:type dim: int
:param num_heads: number of attention heads, defaults to 8
:type num_heads: int, optional
:param qkv_bias: enable bias for qkv if True, defaults to False
:type qkv_bias: bool, optional
:param attn_drop: dropout probability for attention layer, defaults to 0.
:type attn_drop: float, optional
:param proj_drop: dropout probability for linear layer, defaults to 0.
:type proj_drop: float, optional
"""
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C //
self.num_heads).permute(2, 0, 3, 1, 4)
# make torchscript happy (cannot use tensor as tuple)
q, k, v = qkv[0], qkv[1], qkv[2]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
@LAYERS.register_module
class VanillaViTBlock(nn.Module):
"""Vanilla Vision Transformer block
:param dim: dimension of input tensor
:type dim: int
:param num_heads: number of attention heads
:type num_heads: int
:param mlp_ratio: hidden size of MLP divided by embedding dim, defaults to 4.
:type mlp_ratio: float, optional
:param qkv_bias: enable bias for qkv if True, defaults to False
:type qkv_bias: bool, optional
:param drop: dropout probability, defaults to 0.
:type drop: float, optional
:param attn_drop: dropout probability for attention layer, defaults to 0.
:type attn_drop: float, optional
:param drop_path: drop path probability, defaults to 0.
:type drop_path: float, optional
:param act_layer: activation function, defaults to nn.GELU
:type act_layer: torch.nn.Module, optional
:param norm_layer: normalization layer, defaults to nn.LayerNorm
:type norm_layer: torch.nn.Module, optional
"""
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = LAYERS.get_module('VanillaViTAttention')(dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=drop)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = LAYERS.get_module('VanillaViTDropPath')(
drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = LAYERS.get_module('VanillaViTMLP')(in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop)
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
@LAYERS.register_module
class VanillaViTHead(nn.Module):
"""Output layer of vanilla Vision Transformer
:param in_features: size of input tensor
:type in_features: int
:param intermediate_features: hidden size
:type intermediate_features: int
:param out_features: size of output tensor
:type out_features: int
:param bias: whether to add bias, defaults to True
:type bias: bool, optional
"""
def __init__(self,
in_features,
intermediate_features,
out_features,
bias=True
):
super().__init__()
self.linear_1 = nn.Linear(
in_features, intermediate_features, bias=bias)
self.act = nn.Tanh()
self.linear_2 = nn.Linear(
intermediate_features, out_features, bias=bias)
def forward(self, x):
x = x[:, 0, :].squeeze(1)
x = self.linear_1(x)
x = self.act(x)
x = self.linear_2(x)
return x

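A small arithmetic check on the patch-embedding grid above (hypothetical sizes): a 224x224 image with 16x16 patches yields 14 x 14 = 196 patches, plus one class token, so the blocks see sequences of length 197.

img_size, patch_size = 224, 16
grid = img_size // patch_size
num_patches = grid * grid
print(grid, num_patches, num_patches + 1)   # 14 196 197 (class token included)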
3
colossalai/nn/layer/wrapper/__init__.py

@ -0,0 +1,3 @@
from .lambda_wrapper import LambdaWrapper
__all__ = ['LambdaWrapper']

37
colossalai/nn/layer/wrapper/lambda_wrapper.py

@ -0,0 +1,37 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.nn as nn
from colossalai.builder import build_layer
from colossalai.registry import LAYERS
@LAYERS.register_module
class LambdaWrapper(nn.Module):
"""Wrap a function to nn.Module, which takes a config of layers and can fully access them
:param func: user customed function
:type func: Callable
:param layers_cfg: config of layers, defaults to None
:type layers_cfg: dict, optional
"""
def __init__(self, func, layers_cfg: dict = None):
super().__init__()
self.func = func
self.layers = self._build_layers(layers_cfg)
    def _build_layers(self, layers_cfg: dict):
        if layers_cfg is None:
            return None
        else:
            # use a ModuleList so that the built layers are registered as submodules
            layers = nn.ModuleList()
            for cfg in layers_cfg:
                layer = build_layer(cfg)
                layers.append(layer)
            return layers
def forward(self, *args, **kwargs):
return self.func(self, *args, **kwargs)

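A hypothetical usage sketch (assumes colossalai is importable): with layers_cfg=None the wrapper simply lifts a plain function into a module, and forward() passes the module itself as the first argument.

import torch
from colossalai.nn.layer.wrapper import LambdaWrapper

double = LambdaWrapper(func=lambda module, x: 2 * x)
print(double(torch.ones(3)))   # tensor([2., 2., 2.])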
6
colossalai/nn/loss/__init__.py

@ -0,0 +1,6 @@
from .base_loss import BaseLoss
from .cross_entropy_2d import CrossEntropyLoss2D
from .cross_entropy_2p5d import CrossEntropyLoss2p5D
from .cross_entropy_3d import CrossEntropyLoss3D
__all__ = ['CrossEntropyLoss2D', 'CrossEntropyLoss2p5D', 'CrossEntropyLoss3D']

13
colossalai/nn/loss/base_loss.py

@ -0,0 +1,13 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from abc import ABC, abstractmethod
class BaseLoss(ABC):
"""Absctract loss class
"""
@abstractmethod
def calc_loss(self, *args, **kwargs):
pass

120
colossalai/nn/loss/cross_entropy_1d.py

@ -0,0 +1,120 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer.parallel_1d._utils import vocab_range_from_per_partition_vocab_size
class _VocabParallelCrossEntropy_1D(torch.autograd.Function):
@staticmethod
def forward(ctx, vocab_parallel_logits, target):
# Maximum value along vocab dimension across all GPUs.
logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
torch.distributed.all_reduce(logits_max,
op=torch.distributed.ReduceOp.MAX,
group=gpc.get_group(ParallelMode.PARALLEL_1D))
# Subtract the maximum value.
vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
        # Get the partition's vocab indices
partition_vocab_size = vocab_parallel_logits.size()[-1]
rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
vocab_start_index, vocab_end_index = vocab_range_from_per_partition_vocab_size(
partition_vocab_size, rank, world_size)
# Create a mask of valid vocab ids (1 means it needs to be masked).
target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
masked_target = target.clone() - vocab_start_index
masked_target[target_mask] = 0
# Get predicted-logits = logits[target].
# For Simplicity, we convert logits to a 2-D tensor with size
# [*, partition-vocab-size] and target to a 1-D tensor of size [*].
logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
masked_target_1d = masked_target.view(-1)
arange_1d = torch.arange(start=0, end=logits_2d.size()[0],
device=logits_2d.device)
predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
predicted_logits_1d = predicted_logits_1d.clone().contiguous()
predicted_logits = predicted_logits_1d.view_as(target)
predicted_logits[target_mask] = 0.0
# All reduce is needed to get the chunks from other GPUs.
torch.distributed.all_reduce(predicted_logits,
op=torch.distributed.ReduceOp.SUM,
group=gpc.get_group(ParallelMode.PARALLEL_1D))
# Sum of exponential of logits along vocab dimension across all GPUs.
exp_logits = vocab_parallel_logits
torch.exp(vocab_parallel_logits, out=exp_logits)
sum_exp_logits = exp_logits.sum(dim=-1)
torch.distributed.all_reduce(sum_exp_logits,
op=torch.distributed.ReduceOp.SUM,
group=gpc.get_group(ParallelMode.PARALLEL_1D))
# Loss = log(sum(exp(logits))) - predicted-logit.
loss = torch.log(sum_exp_logits) - predicted_logits
# Store softmax, target-mask and masked-target for backward pass.
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
return loss
@staticmethod
def backward(ctx, grad_output):
        # Retrieve tensors from the forward path.
        softmax, target_mask, masked_target_1d = ctx.saved_tensors
        # All the inputs have softmax as their gradient.
grad_input = softmax
# For simplicity, work with the 2D gradient.
partition_vocab_size = softmax.size()[-1]
grad_2d = grad_input.view(-1, partition_vocab_size)
# Add the gradient from matching classes.
arange_1d = torch.arange(start=0, end=grad_2d.size()[0],
device=grad_2d.device)
grad_2d[arange_1d, masked_target_1d] -= (
1.0 - target_mask.view(-1).float())
# Finally elementwise multiplication with the output gradients.
grad_input.mul_(grad_output.unsqueeze(dim=-1))
return grad_input, None
class LmLoss1D(_Loss):
def forward(self, lm_logits, lm_labels, loss_mask):
lm_loss = _VocabParallelCrossEntropy_1D.apply(lm_logits, lm_labels)
lm_loss = torch.sum(
lm_loss.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
return lm_loss
class SopLoss1D(_Loss):
def forward(self, sop_logits, sentence_order):
sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
sentence_order.view(-1),
ignore_index=-1)
return sop_loss
class BERTDualHeadLoss(_Loss):
    def __init__(self):
        # _Loss is an nn.Module, so its constructor must run before submodules are assigned
        super().__init__()
        self.lm_loss = LmLoss1D()
        self.sop_loss = SopLoss1D()
def forward(self, lm_logits, sop_logits, lm_labels, loss_mask, sentence_order):
lm_loss = self.lm_loss(lm_logits, lm_labels, loss_mask)
sop_loss = self.sop_loss(sop_logits, sentence_order)
return lm_loss + sop_loss

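For intuition, the quantity _VocabParallelCrossEntropy_1D assembles across ranks is the standard per-token loss log(sum_j exp(logit_j)) - logit_target. A single-process reference (world size 1, hypothetical sizes) checks this against PyTorch's cross entropy:

import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)                  # [tokens, vocab]
targets = torch.randint(0, 10, (4,))
manual = torch.logsumexp(logits, dim=-1) - logits[torch.arange(4), targets]
print(torch.allclose(manual, F.cross_entropy(logits, targets, reduction='none')))  # True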
128
colossalai/nn/loss/cross_entropy_2d.py

@ -0,0 +1,128 @@
import torch
import torch.distributed as dist
from torch.nn.modules.loss import _Loss
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization, get_summa_dim_from_env
from colossalai.registry import LOSSES
from colossalai.utils import get_current_device
class _ParallelCrossEntropyLossFunction_2D(torch.autograd.Function):
    ### Adapted from megatron.mpu.cross_entropy ###
@staticmethod
def forward(ctx, logits, targets):
# logits: [b/q, h/q]
# labels: [b/q]
# loss: [b/q]
# vocab_parallel_logits: [b/q, s, v/q]
# target: [b/q, s]
logits_max = torch.max(logits, dim=-1)[0]
torch.distributed.all_reduce(
logits_max,
op=torch.distributed.ReduceOp.MAX,
group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW))
# Subtract the maximum value.
# vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
logits = logits - logits_max.unsqueeze(dim=-1)
vocab_size = logits.size(-1)
rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
vocab_start = rank * (vocab_size)
vocab_end = (rank + 1) * (vocab_size) - 1
target_mask = (targets < vocab_start) | (targets > vocab_end)
masked_target = targets.clone() - vocab_start
masked_target[target_mask] = 0
        # keep the index tensor on the same device as the logits (matches the backward pass)
        arange_1d = torch.arange(
            start=0, end=logits.size()[0], device=get_current_device()
        )
predicted_logits = logits[arange_1d, masked_target]
predicted_logits[target_mask] = 0.
dist.all_reduce(predicted_logits, group=gpc.get_group(
ParallelMode.PARALLEL_2D_ROW))
exp_logits = torch.exp(logits)
sum_exp_logits = exp_logits.sum(dim=1)
dist.all_reduce(sum_exp_logits, group=gpc.get_group(
ParallelMode.PARALLEL_2D_ROW))
loss = torch.log(sum_exp_logits) - predicted_logits
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
ctx.save_for_backward(exp_logits, target_mask, masked_target)
return loss
@staticmethod
def backward(ctx, output_grad):
        # Retrieve tensors from the forward path.
softmax, target_mask, masked_target = ctx.saved_tensors
# All the inputs have softmax as their gradient.
grad_input = softmax
# For simplicity, work with the 2D gradient.
partition_vocab_size = softmax.size()[-1]
grad_2d = grad_input.view(-1, partition_vocab_size)
# Add the gradient from matching classes.
arange_1d = torch.arange(start=0, end=grad_2d.size()[0],
device=get_current_device())
grad_2d[arange_1d,
masked_target] -= (1.0 - target_mask.view(-1).float())
# Finally elementwise multiplication with the output gradients.
grad_input.mul_(output_grad.unsqueeze(dim=-1))
return grad_input, None
class _ReduceByColumn(torch.autograd.Function):
"""All-reduce the input from the model parallel region."""
@staticmethod
def symbolic(graph, input_):
dist.all_reduce(input_, group=gpc.get_group(
ParallelMode.PARALLEL_2D_COL))
return input_
@staticmethod
def forward(ctx, input_):
dist.all_reduce(input_, group=gpc.get_group(
ParallelMode.PARALLEL_2D_COL))
return input_
@staticmethod
def backward(ctx, grad_output):
return grad_output
@LOSSES.register_module
class CrossEntropyLoss2D(_Loss):
"""Cross entropy loss for 2D parallelism
:param reduction: whether to average the loss, defaults to True
:type reduction: bool, optional
"""
def __init__(self, reduction=True):
super().__init__()
assert_summa_initialization()
self.summa_dim = get_summa_dim_from_env()
self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
self.reduction_mean = reduction
def forward(self, logits, targets):
targets = targets.chunk(self.summa_dim, dim=0)[self.row_rank]
loss = _ParallelCrossEntropyLossFunction_2D.apply(
logits, targets,
)
if self.reduction_mean:
loss = _ReduceByColumn.apply(loss) / self.summa_dim
dist_loss = loss.mean()
return dist_loss

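One detail worth calling out in CrossEntropyLoss2D.forward: the global targets are chunked along the batch dimension and each row of the 2D mesh keeps its own slice. A standalone illustration with hypothetical sizes:

import torch

targets = torch.arange(8)                          # global batch of 8 labels
summa_dim, row_rank = 2, 1
print(targets.chunk(summa_dim, dim=0)[row_rank])   # tensor([4, 5, 6, 7])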