mirror of https://github.com/hpcaitech/ColossalAI
Commit by zbian, 3 years ago — 409 changed files with 35853 additions and 0 deletions
.gitignore
@ -0,0 +1,144 @@

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/
docs/.build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# IDE
.idea/
.vscode/

# macos
.DS_Store
#data/

# launcher setting
tests/launcher/log
tests/launcher/personal

docs/.build
MANIFEST.in
@ -0,0 +1,4 @@

include *.txt README.md
recursive-include requirements *.txt
recursive-include colossalai *.cpp *.h *.cu *.tr *.cuh *.cc
recursive-include csrc *.cpp *.h *.cu *.tr *.cuh *.cc
README.md
@ -0,0 +1,104 @@

# ColossalAI

An integrated large-scale model training framework with efficient parallelization techniques.

## Installation

### PyPI

```bash
pip install colossalai
```

### Install From Source

```shell
git clone git@github.com:hpcaitech/ColossalAI.git
cd ColossalAI
# install dependencies
pip install -r requirements/requirements.txt

# install colossalai
pip install .
```

Install and enable CUDA kernel fusion (required when using fused optimizers):

```shell
pip install -v --no-cache-dir --global-option="--cuda_ext" .
```

## Documentation

- [Documentation](https://www.colossalai.org/)

## Quick View

### Start Distributed Training in a Few Lines

```python
import colossalai
from colossalai.engine import Engine
from colossalai.trainer import Trainer
from colossalai.core import global_context as gpc

model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize()
engine = Engine(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    schedule=schedule
)

trainer = Trainer(engine=engine,
                  hooks_cfg=gpc.config.hooks,
                  verbose=True)
trainer.fit(
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
    max_epochs=gpc.config.num_epochs,
    display_progress=True,
    test_interval=5
)
```
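
The snippet above reads `gpc.config.hooks` and `gpc.config.num_epochs` from the global context, so the config file passed to `colossalai.initialize()` must define those fields. Below is a minimal, illustrative config sketch; apart from `hooks`, `num_epochs` and `parallel`, the exact field layout and hook type names depend on your ColossalAI version and are assumptions here, not a fixed schema.

```python
# config.py -- illustrative sketch only
num_epochs = 10                          # read by the trainer as gpc.config.num_epochs

parallel = dict(                         # assumed layout of the parallelism settings
    pipeline=dict(size=1),
    tensor=dict(size=4, mode='2d'),
)

hooks = [                                # read by the trainer as gpc.config.hooks
    dict(type='LogMetricByEpochHook'),   # hook type name is a placeholder
]
```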

### Write a Simple 2D Parallel Model

Suppose we have a huge MLP model whose very large hidden size makes it difficult to fit into a single GPU. We can
then distribute the model weights across GPUs in a 2D mesh while still writing the model in a familiar way.

```python
from colossalai.nn import Linear2D
import torch.nn as nn


class MLP_2D(nn.Module):

    def __init__(self):
        super().__init__()
        self.linear_1 = Linear2D(in_features=1024, out_features=16384)
        self.linear_2 = Linear2D(in_features=16384, out_features=1024)

    def forward(self, x):
        x = self.linear_1(x)
        x = self.linear_2(x)
        return x
```

## Features

ColossalAI provides a collection of parallel training components. We aim to let you write your
distributed deep learning models just like you write your single-GPU model, and we provide friendly tools to kickstart
distributed training in a few lines.

- [Data Parallelism](./docs/parallelization.md)
- [Pipeline Parallelism](./docs/parallelization.md)
- [1D, 2D, 2.5D, 3D and Sequence Parallelism](./docs/parallelization.md)
- [Friendly Trainer and Engine](./docs/trainer_engine.md)
- [Extensible for New Parallelism](./docs/add_your_parallel.md)
- [Mixed Precision Training](./docs/amp.md)
- [Zero Redundancy Optimizer (ZeRO)](./docs/zero.md)
colossalai/__init__.py
@ -0,0 +1,4 @@

from .initialize import init_dist, initialize
from .nn import *

__version__ = '0.0.1'
colossalai/builder/__init__.py
@ -0,0 +1,2 @@

from .builder import *
from .pipeline import ModelInitializer
colossalai/builder/builder.py
@ -0,0 +1,262 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
import inspect |
||||
from collections.abc import Iterable |
||||
|
||||
from colossalai.registry import * |
||||
|
||||
|
||||
def build_from_config(module, config: dict): |
||||
"""Returns an object of :class:`module` constructed from `config`. |
||||
|
||||
:param module: A python or user-defined class |
||||
:type module: class |
||||
:param config: A python dict containing information used in the construction |
||||
of the return object |
||||
:type config: dict |
||||
:raises AssertionError: Raises an AssertionError if `module` is not a class |
||||
:return: An object of :class:`module` |
||||
:rtype: :class:`module` |
||||
""" |
||||
assert inspect.isclass(module), 'module must be a class' |
||||
return module(**config) |
||||
|
||||
|
||||
def build_from_registry(config, registry: Registry): |
||||
"""Returns an object constructed from `config`, the type of the object |
||||
is specified by `registry`. |
||||
|
||||
:param config: A python dict or a :class:`colossalai.context.Config` object |
||||
containing information used in the construction of the return object |
||||
:type config: dict or :class:`colossalai.context.colossalai.context.Config` |
||||
:param registry: A registry specifying the type of the return object |
||||
:type registry: :class:`Registry` |
||||
:raises AssertionError: Raises an AssertionError if `registry` is not an object |
||||
of :class:`Registry` or `mod_type` in `config` is not found in `registry` |
||||
:raises Exception: Raises an Exception if an error occurred when building |
||||
from registry |
||||
:return: An object specified by `registry` |
||||
:rtype: Python object specified by `registry` |
||||
""" |
||||
config_ = config.copy() # keep the original config untouched |
||||
assert isinstance( |
||||
registry, Registry), f'Expected type Registry but got {type(registry)}' |
||||
|
||||
mod_type = config_.pop('type') |
||||
assert registry.has( |
||||
mod_type), f'{mod_type} is not found in registry {registry.name}' |
||||
try: |
||||
obj = registry.get_module(mod_type)(**config_) |
||||
except Exception as e: |
||||
print( |
||||
f'An error occurred when building {mod_type} from registry {registry.name}', flush=True) |
||||
raise e |
||||
|
||||
return obj |
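# --- Illustrative usage (annotation, not part of this file) ---
# build_from_config simply forwards the dict as keyword arguments to the class:
#
#     import torch.nn as nn
#     layer = build_from_config(nn.Linear, {'in_features': 16, 'out_features': 4})
#
# build_from_registry additionally looks the class up by its 'type' key, so a config
# such as dict(type='Linear2D', in_features=1024, out_features=4096) would work,
# assuming 'Linear2D' (shown in the README) is registered in the LAYERS registry.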
||||
|
||||
|
||||
def build_layer(config): |
||||
"""Returns a layer object of :class:`nn.Module` constructed from `config`. |
||||
|
||||
:param config: A python dict or a :class:`colossalai.context.Config` object |
||||
containing information used in the construction of the return object |
||||
:type config: dict or :class:`colossalai.context.Config` |
||||
:return: An object of :class:`nn.Module` |
||||
:rtype: :class:`nn.Module` |
||||
""" |
||||
return build_from_registry(config, LAYERS) |
||||
|
||||
|
||||
def build_loss(config): |
||||
"""Returns a loss function object of :class:`torch.autograd.Function` constructed |
||||
from `config`. |
||||
|
||||
:param config: A python dict or a :class:`colossalai.context.Config` object |
||||
containing information used in the construction of the return object |
||||
:type config: dict or :class:`colossalai.context.Config` |
||||
:return: An object of :class:`torch.autograd.Function` |
||||
:rtype: :class:`torch.autograd.Function` |
||||
""" |
||||
return build_from_registry(config, LOSSES) |
||||
|
||||
|
||||
def build_model(config): |
||||
"""Returns a model object of :class:`nn.Module` constructed from `config`. |
||||
|
||||
:param config: A python dict or a :class:`colossalai.context.Config` object |
||||
containing information used in the construction of the return object |
||||
:type config: dict or :class:`colossalai.context.Config` |
||||
:return: An object of :class:`nn.Module` |
||||
:rtype: :class:`nn.Module` |
||||
""" |
||||
return build_from_registry(config, MODELS) |
||||
|
||||
|
||||
def build_dataset(config): |
||||
"""Returns a dataset object of :class:`torch.utils.data.Dataset` constructed |
||||
from `config`. |
||||
|
||||
:param config: A python dict or a :class:`colossalai.context.Config` object |
||||
containing information used in the construction of the return object |
||||
:type config: dict or :class:`colossalai.context.Config` |
||||
:return: An object of :class:`torch.utils.data.Dataset` |
||||
:rtype: :class:`torch.utils.data.Dataset` |
||||
""" |
||||
return build_from_registry(config, DATASETS) |
||||
|
||||
|
||||
def build_optimizer(config, model, params: Iterable = None, need_module=False): |
||||
"""Returns an optimizer object of :class:`torch.optim.Optimizer` constructed from `config`, |
||||
'model' and 'params'. |
||||
|
||||
:param config: A python dict or a :class:`colossalai.context.Config` object |
||||
containing information used in the construction of the return object |
||||
:type config: dict or :class:`colossalai.context.Config` |
||||
:param model: A model containing parameters for the optimizer |
||||
:type model: :class:`nn.Module` |
||||
:param params: A dict containing parameters for the optimizer |
||||
:type params: dict, optional |
||||
:param need_module: Indicates whether the optimizer needs a module |
||||
:type need_module: bool, optional
||||
:raises AssertionError: Raises an AssertionError if both `model` and `params` are None |
||||
:return: An object of :class:`torch.optim.Optimizer` |
||||
:rtype: :class:`torch.optim.Optimizer` |
||||
""" |
||||
assert model is not None or params is not None, 'arguments model and params can not both be None' |
||||
if need_module: |
||||
config['module'] = model |
||||
elif model is not None: |
||||
config['params'] = model.parameters() |
||||
elif params is not None: |
||||
config['params'] = params |
||||
|
||||
return build_from_registry(config, OPTIMIZERS) |
||||
|
||||
|
||||
def build_gradient_handler(config, model, optimizer): |
||||
"""Returns a gradient handler object of :class:`BaseGradientHandler` constructed from `config`, |
||||
`model` and `optimizer`. |
||||
|
||||
:param config: A python dict or a :class:`colossalai.context.Config` object |
||||
containing information used in the construction of the return object |
||||
:type config: dict or :class:`colossalai.context.Config` |
||||
:param model: A model containing parameters for the gradient handler |
||||
:type model: :class:`nn.Module` |
||||
:param optimizer: An optimizer object containing parameters for the gradient handler |
||||
:type optimizer: :class:`torch.optim.Optimizer` |
||||
:return: An object of :class:`BaseGradientHandler` |
||||
:rtype: :class:`BaseGradientHandler` |
||||
""" |
||||
config_ = config.copy() |
||||
mod_type = config_.pop('type') |
||||
return GRADIENT_HANDLER.get_module(mod_type)(model, optimizer, **config_) |
||||
|
||||
|
||||
def build_hooks(config, trainer): |
||||
"""Returns a hook object of :class:`BaseHook` constructed from `config` and `trainer`. |
||||
|
||||
:param config: A python dict or a :class:`colossalai.context.Config` object |
||||
containing information used in the construction of the return object |
||||
:type config: dict or :class:`colossalai.context.Config` |
||||
:param trainer: A :class:`Trainer` object containing parameters for the hook |
||||
:type trainer: :class:`Trainer` |
||||
:return: An object of :class:`BaseHook` |
||||
:rtype: :class:`BaseHook` |
||||
""" |
||||
config['trainer'] = trainer |
||||
return build_from_registry(config, HOOKS) |
||||
|
||||
|
||||
def build_transform(config): |
||||
"""Returns a transformation object of :class:`torchvision.transforms` constructed |
||||
from `config`. |
||||
|
||||
:param config: A python dict or a :class:`colossalai.context.Config` object |
||||
containing information used in the construction of the return object |
||||
:type config: dict or :class:`colossalai.context.Config` |
||||
:return: An object of :class:`torchvision.transforms` |
||||
:rtype: :class:`torchvision.transforms` |
||||
""" |
||||
return build_from_registry(config, TRANSFORMS) |
||||
|
||||
|
||||
def build_pipe_alloc_policy(config): |
||||
"""Returns a pipeline allocation policy object constructed from `config`. |
||||
|
||||
:param config: A python dict or a :class:`colossalai.context.Config` object |
||||
containing information used in the construction of the return object |
||||
:type config: dict or :class:`colossalai.context.Config` |
||||
:return: A pipeline allocation policy object registered in PIPE_ALLOC_POLICY
||||
""" |
||||
return build_from_registry(config, PIPE_ALLOC_POLICY) |
||||
|
||||
|
||||
def build_data_sampler(config, dataset): |
||||
"""Returns a data sampler object of :class:`colossalai.nn.data.sampler.BaseSampler` |
||||
constructed from `config`. |
||||
|
||||
:param config: A python dict or a :class:`colossalai.context.Config` object |
||||
containing information used in the construction of the return object |
||||
:type config: dict or :class:`colossalai.context.Config` |
||||
:param dataset: An object of :class:`torch.utils.data.Dataset` containing information |
||||
used in the construction of the return object |
||||
:type dataset: :class:`torch.utils.data.Dataset` |
||||
:return: An object of :class:`colossalai.nn.data.sampler.BaseSampler` |
||||
:rtype: :class:`colossalai.nn.data.sampler.BaseSampler` |
||||
""" |
||||
config_ = config.copy() |
||||
mod_type = config_.pop('type') |
||||
return SAMPLERS.get_module(mod_type)(dataset, **config_) |
||||
|
||||
|
||||
def build_optimizer_wrapper(config, optimizer, model=None): |
||||
"""Returns an optimizer wrapper object of :class:`torch.optim.Optimizer` constructed |
||||
from `config`, `model` and `optimizer`. |
||||
|
||||
:param config: A python dict or a :class:`colossalai.context.Config` object |
||||
containing information used in the construction of the return object |
||||
:type config: dict or :class:`colossalai.context.Config` |
||||
:param optimizer: An optimizer object containing parameters for the gradient handler |
||||
:type optimizer: :class:`torch.optim.Optimizer` |
||||
:param model: A model containing parameters for the gradient handler |
||||
:type model: :class:`nn.Module`, optional |
||||
:return: An object of :class:`torch.optim.Optimizer` |
||||
:rtype: :class:`torch.optim.Optimizer` |
||||
""" |
||||
config_ = config.copy() |
||||
mod_type = config_.pop('type') |
||||
|
||||
# LSG: special treatment for ZeRO level 3
||||
if mod_type == 'ZeroRedundancyOptimizer_Level_3': |
||||
return OPTIMIZER_WRAPPERS.get_module(mod_type)(model, optimizer, **config_) |
||||
else: |
||||
return OPTIMIZER_WRAPPERS.get_module(mod_type)(optimizer, **config_) |
||||
|
||||
|
||||
def build_lr_scheduler(config, optimizer, total_steps, num_steps_per_epoch): |
||||
"""Returns a learning rate scheduler object of :class:`torch.optim.lr_scheduler` |
||||
constructed from `config`, `optimizer`, `total_steps` and `num_steps_per_epoch`. |
||||
|
||||
:param config: A python dict or a :class:`colossalai.context.Config` object |
||||
containing information used in the construction of the return object |
||||
:type config: dict or :class:`colossalai.context.Config` |
||||
:param optimizer: An optimizer object containing parameters for the learning rate |
||||
scheduler |
||||
:type optimizer: :class:`torch.optim.Optimizer` |
||||
:param total_steps: Number of total steps of the learning rate scheduler |
||||
:type total_steps: int |
||||
:param num_steps_per_epoch: number of steps per epoch of the learning rate scheduler |
||||
:type num_steps_per_epoch: int |
||||
:return: An object of :class:`torch.optim.lr_scheduler` |
||||
:rtype: :class:`torch.optim.lr_scheduler` |
||||
""" |
||||
config_ = config.copy() |
||||
mod_type = config_.pop('type') |
||||
# warmup epochs will overwrite warmup steps |
||||
if 'warmup_epochs' in config_: |
||||
warmup_epochs = config_.pop('warmup_epochs') |
||||
config_['warmup_steps'] = int(num_steps_per_epoch * warmup_epochs) |
||||
return LR_SCHEDULERS.get_module(mod_type)(optimizer, total_steps, num_steps_per_epoch=num_steps_per_epoch, |
||||
**config_) |
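# --- Illustrative usage (annotation, not part of this file) ---
# build_lr_scheduler converts 'warmup_epochs' in the config into 'warmup_steps' before
# construction. The scheduler type below is a placeholder; it must be registered in
# LR_SCHEDULERS and accept (optimizer, total_steps, num_steps_per_epoch=..., warmup_steps=...):
#
#     cfg = dict(type='CosineAnnealingWarmupLR', warmup_epochs=2)   # placeholder type name
#     scheduler = build_lr_scheduler(cfg, optimizer,
#                                    total_steps=100 * 50,
#                                    num_steps_per_epoch=50)        # warmup_steps -> 2 * 50 = 100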
colossalai/builder/pipeline.py
@ -0,0 +1,226 @@
|
||||
import copy |
||||
import heapq |
||||
|
||||
from colossalai.builder import build_model, build_layer |
||||
from colossalai.context.parallel_mode import ParallelMode |
||||
from colossalai.core import global_context as gpc |
||||
from colossalai.logging import get_global_dist_logger |
||||
from colossalai.utils import set_to_cuda |
||||
|
||||
|
||||
def _binary_partition(weights, st, ed): |
||||
"""Returns the binary partition position of `weights`, given the start |
||||
position `st` and the end position `ed`. |
||||
|
||||
:param weights: A python list to be binary partitioned |
||||
:type weights: list |
||||
:param st: the start position of the binary partition |
||||
:type st: int |
||||
:param ed: the end position of the binary partition
||||
:type ed: int |
||||
:return: the binary partition position of `weights` |
||||
:rtype: int |
||||
""" |
||||
w_sum = weights[ed - 1] |
||||
prefix = 0 |
||||
if st > 0: |
||||
w_sum -= weights[st - 1] |
||||
prefix = weights[st - 1] |
||||
minimum = float("inf") |
||||
for idx in range(st + 1, ed): |
||||
front = weights[idx - 1] - prefix |
||||
diff = abs(w_sum - 2 * front) |
||||
if diff < minimum: |
||||
pos = idx |
||||
minimum = diff |
||||
|
||||
return st, pos, ed |
||||
|
||||
|
||||
def _heap_addition(weights, intervals, add_cnt): |
||||
""" |
||||
""" |
||||
def _heap_push(heap, st, ed): |
||||
value = weights[ed - 1] |
||||
if st > 0: |
||||
value -= weights[st - 1] |
||||
heapq.heappush(heap, (-value, st, ed)) |
||||
|
||||
ret_intervals = [] |
||||
heap = [] |
||||
|
||||
for st, ed in intervals: |
||||
_heap_push(heap, st, ed) |
||||
|
||||
while add_cnt > 0: |
||||
_, st, ed = heapq.heappop(heap) |
||||
if ed - st == 1: |
||||
ret_intervals.append((st, ed)) |
||||
else: |
||||
l, m, r = _binary_partition(weights, st, ed) |
||||
_heap_push(heap, l, m) |
||||
_heap_push(heap, m, r) |
||||
add_cnt -= 1 |
||||
|
||||
while heap: |
||||
_, st, ed = heapq.heappop(heap) |
||||
ret_intervals.append((st, ed)) |
||||
|
||||
ret_intervals.sort() |
||||
return ret_intervals |
||||
|
||||
|
||||
def _calc_partitions(weights, value): |
||||
prev = 0 |
||||
prefix = 0 |
||||
num_block = 0 |
||||
intervals = [] |
||||
|
||||
for idx, w in enumerate(weights): |
||||
if weights[idx] - prefix > value: |
||||
intervals.append((prev, idx)) |
||||
prev = idx |
||||
prefix = weights[idx - 1] |
||||
num_block += 1 |
||||
|
||||
intervals.append((prev, len(weights))) |
||||
return num_block + 1, intervals |
||||
|
||||
|
||||
def _binary_search(weights, num): |
||||
length = len(weights) |
||||
prefix = [1 if w == 0 else w for w in weights] |
||||
for i in range(1, length): |
||||
prefix[i] += prefix[i - 1] |
||||
|
||||
lower_bound = max(weights) |
||||
upper_bound = prefix[length - 1] |
||||
|
||||
while upper_bound > lower_bound: |
||||
mid = (upper_bound + lower_bound) // 2 |
||||
number, _ = _calc_partitions(prefix, mid) |
||||
if number <= num: |
||||
upper_bound = mid |
||||
else: |
||||
lower_bound = mid + 1 |
||||
|
||||
num_block, intervals = _calc_partitions(prefix, upper_bound) |
||||
if num_block < num: |
||||
intervals = _heap_addition(prefix, intervals, num - num_block) |
||||
|
||||
return intervals |
||||
|
||||
|
||||
def _partition_uniform(num_items, num_parts, num_chunks): |
||||
assert num_items % num_chunks == 0, \ |
||||
"Layer length should be divided by the number of chunks, otherwise parameter method is recomended" |
||||
|
||||
logger = get_global_dist_logger() |
||||
parts = [[] for _ in range(num_parts)] |
||||
partition_items = num_items // num_chunks |
||||
for idx in range(num_chunks): |
||||
base_idx = idx * partition_items |
||||
chunk_size = partition_items // num_parts |
||||
left = num_parts - partition_items % num_parts |
||||
if chunk_size == 0: |
||||
logger.warning("Some nodes in Pipeline have no requests") |
||||
|
||||
for p in range(num_parts): |
||||
st = base_idx |
||||
base_idx += chunk_size + (p >= left) |
||||
parts[p].append((st, base_idx)) |
||||
|
||||
return parts |
||||
|
||||
|
||||
def _partition_balanced(weights, num_parts, num_chunks): |
||||
num_total = num_parts * num_chunks |
||||
num_items = len(weights) |
||||
if num_items <= num_total: |
||||
return _partition_uniform(num_items, num_parts, num_chunks) |
||||
|
||||
intervals = _binary_search(weights, num_total) |
||||
|
||||
current = 0 |
||||
parts = [[] for _ in range(num_parts)] |
||||
for inter in intervals: |
||||
parts[current].append(inter) |
||||
current = (current + 1) % num_parts |
||||
|
||||
return parts |
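# --- Illustrative example (annotation, not part of this file) ---
# Balancing 8 layers across 2 pipeline stages (1 chunk each) by parameter count:
#
#     weights = [4, 1, 1, 1, 1, 1, 1, 4]              # per-layer parameter counts
#     parts = _partition_balanced(weights, num_parts=2, num_chunks=1)
#     # -> [[(0, 4)], [(4, 8)]]: each stage holds layers summing to 7 units of weight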
||||
|
||||
|
||||
class ModelInitializer(): |
||||
def __init__(self, config, num_chunks, verbose=False): |
||||
self.num_chunks = num_chunks |
||||
self.ori_model = build_model(config) |
||||
self.layers = self.ori_model.layers_cfg |
||||
layer_length = len(self.layers) |
||||
self.verbose = verbose |
||||
self._logger = get_global_dist_logger() |
||||
self._logger.info(f"The total length of layers is {layer_length}", ranks=[0]) |
||||
|
||||
def model_initialize(self, partition_method='parameter'): |
||||
# Some space for initializing communication groups
||||
self._interval = None |
||||
self._partition_layers(method=partition_method) |
||||
models = self._build() |
||||
model = set_to_cuda(models) |
||||
|
||||
return model |
||||
|
||||
def _partition_layers(self, method): |
||||
pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE) |
||||
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) |
||||
|
||||
method = method.lower() |
||||
# Make a partition |
||||
if method == 'layer': |
||||
num_layers = len(self.layers) |
||||
self.parts = _partition_uniform(num_layers, pipeline_parallel_size, self.num_chunks) |
||||
elif method == 'parameter': |
||||
param_counts = self._count_layer_params() |
||||
# print_rank_0(param_counts) |
||||
self.parts = _partition_balanced(param_counts, pipeline_parallel_size, self.num_chunks) |
||||
else: |
||||
assert method == 'layer', "Method should be a pre-set string" |
||||
|
||||
# Display the partition |
||||
if gpc.get_global_rank() == 0 and self.verbose: |
||||
log_str = 'Layer allocation after partitioning: \n' |
||||
for stage in range(pipeline_parallel_size): |
||||
|
||||
num_layers = 0 |
||||
for st, ed in self.parts[stage]: |
||||
num_layers += ed - st |
||||
|
||||
log_str += f'\n===== stage={stage}, layers={num_layers} =====\n' |
||||
for st, ed in self.parts[stage]: |
||||
for idx, layer in enumerate(self.layers[st: ed]): |
||||
log_str += f'\t{idx + st:2d}: {layer}\n' |
||||
self._logger.info(log_str) |
||||
|
||||
# Save the partition |
||||
self._interval = self.parts[pipeline_rank] |
||||
|
||||
def _build(self): |
||||
"""Build model from the layer cfg according to the partition |
||||
""" |
||||
models = [] |
||||
for st, ed in self._interval: |
||||
model = copy.copy(self.ori_model) |
||||
model.build_from_cfg(st, ed) |
||||
models.append(model) |
||||
|
||||
return models |
||||
|
||||
def _count_layer_params(self): |
||||
"""Count the number of parameters in each layer |
||||
""" |
||||
param_counts = [0] * len(self.layers) |
||||
for idx, cfg in enumerate(self.layers): |
||||
layer = build_layer(cfg) |
||||
params = filter(lambda p: p.requires_grad, layer.parameters()) |
||||
param_counts[idx] = sum(p.numel() for p in params) |
||||
|
||||
return param_counts |
colossalai/checkpointing.py
@ -0,0 +1,215 @@
|
||||
import os |
||||
import os.path as osp |
||||
import re |
||||
from typing import Tuple |
||||
|
||||
import torch |
||||
|
||||
from .context import Config |
||||
from .context.parallel_mode import ParallelMode |
||||
from .core import global_context as gpc |
||||
|
||||
__all__ = [ |
||||
'get_checkpoint_path', |
||||
'get_latest_checkpoint_path', |
||||
'get_latest_checkpoint_pattern', |
||||
'save_checkpoint', |
||||
'load_checkpoint' |
||||
] |
||||
|
||||
|
||||
def unwrap_config(config: Config): |
||||
''' |
||||
unwrap Config objects to normal dicts |
||||
''' |
||||
config_dict = dict() |
||||
for k, v in config.items(): |
||||
if isinstance(v, dict): |
||||
config_dict[k] = unwrap_config(v) |
||||
else: |
||||
config_dict[k] = v |
||||
|
||||
return config_dict |
||||
|
||||
|
||||
def _get_ranks_name(): |
||||
# tensor parallel |
||||
tp_local_rank = 0 |
||||
if gpc.is_initialized(ParallelMode.TENSOR): |
||||
tp_local_rank = gpc.get_local_rank(ParallelMode.TENSOR) |
||||
|
||||
# pipeline parallel |
||||
pp_local_rank = 0 |
||||
if gpc.is_initialized(ParallelMode.PIPELINE): |
||||
pp_local_rank = gpc.get_local_rank(ParallelMode.PIPELINE) |
||||
|
||||
ranks_name = f'tp{tp_local_rank}-pp{pp_local_rank}' |
||||
return ranks_name |
||||
|
||||
|
||||
def _get_standard_checkpoint_filename(epoch: int, suffix: str = ''): |
||||
ranks_name = _get_ranks_name() |
||||
return f'epoch{epoch}-{ranks_name}{suffix}.pt' |
||||
|
||||
|
||||
def get_checkpoint_path(checkpoint_dir: str, epoch: int, suffix: str = ''): |
||||
'''This is a function to generate the checkpoint path from the (checkpoint_dir, epoch, suffix, gpu_parallel_rank) tuple. |
||||
This is useful when saving and restoring checkpoints.
||||
|
||||
:param checkpoint_dir: set up a directory for saving checkpoints |
||||
:type checkpoint_dir: str |
||||
:param epoch: epoch number (indicate how many epochs have you trained this model) |
||||
:type epoch: int |
||||
:param suffix: additional notation to specify the model or checkpoint, defaults to '' |
||||
:type suffix: str, optional |
||||
:return: checkpoint path to be generated |
||||
:rtype: path |
||||
''' |
||||
ckpt_filename = _get_standard_checkpoint_filename(epoch, suffix) |
||||
return os.path.join(checkpoint_dir, ckpt_filename) |
||||
|
||||
|
||||
def _ensure_directory_exists(filename: str): |
||||
# ensure the directory exists |
||||
dir = os.path.dirname(filename) |
||||
if not os.path.exists(dir): |
||||
os.makedirs(dir) |
||||
|
||||
|
||||
def get_latest_checkpoint_pattern(suffix: str = ''): |
||||
'''Generate Regular expression of latest checkpoint's pattern |
||||
|
||||
:param suffix: additional notation to specify the model or checkpoint, defaults to '' |
||||
:type suffix: str, optional |
||||
:return: checkpoint pattern |
||||
:rtype: regular expression |
||||
''' |
||||
ranks_name = _get_ranks_name() |
||||
ckpt_pattern = re.compile(rf'epoch(\d+)-{ranks_name}{suffix}\.pt')
||||
return ckpt_pattern |
||||
|
||||
|
||||
def get_latest_checkpoint_path(checkpoint_dir: str, suffix: str = ''): |
||||
'''This is a function to retrieve the latest checkpoint path from the (checkpoint_dir, suffix, gpu_parallel_rank) tuple. |
||||
This is useful when restoring a checkpoint without knowing the epoch number.
||||
|
||||
:param checkpoint_dir: directory for saving checkpoints |
||||
:type checkpoint_dir: str |
||||
:param suffix: additional notation to specify the model or checkpoint, defaults to '' |
||||
:type suffix: str, optional |
||||
:raises FileNotFoundError: raise error when we cannot find the latest checkpoint file with inputs given |
||||
:return: the latest checkpoint path to be retrieved |
||||
:rtype: path |
||||
''' |
||||
CKPT_NAME_PAT = get_latest_checkpoint_pattern(suffix=suffix) |
||||
|
||||
last_epoch = -1 |
||||
assert osp.isdir(checkpoint_dir), f'{checkpoint_dir} is not a directory' |
||||
|
||||
for filename in os.listdir(checkpoint_dir): |
||||
ret = CKPT_NAME_PAT.match(filename) |
||||
if ret: |
||||
epoch = int(ret[0].split('-')[0].lstrip('epoch')) |
||||
if epoch > last_epoch: |
||||
last_epoch = epoch |
||||
|
||||
if last_epoch == -1: |
||||
ranks_name = _get_ranks_name() |
||||
raise FileNotFoundError(f"Cannot find the latest checkpoint file for {ranks_name} in {checkpoint_dir}") |
||||
else: |
||||
target_file = _get_standard_checkpoint_filename(last_epoch, suffix=suffix) |
||||
path = osp.join(checkpoint_dir, target_file) |
||||
return path |
||||
|
||||
|
||||
def save_checkpoint(checkpoint_path: str, |
||||
epoch: int, |
||||
model: torch.nn.Module, |
||||
optimizer: torch.optim.Optimizer, |
||||
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None, |
||||
**kwargs): |
||||
'''Given a checkpoint path, saves the parameters and buffers of all training components, such as the model, optimizer and lr_scheduler, into a checkpoint dictionary.

This method can be used for both the colossalai nn.BaseModel and a normal pytorch nn.Module.
||||
|
||||
|
||||
:param checkpoint_path: set up a directory for saving checkpoints |
||||
:type checkpoint_path: str |
||||
:param epoch: epoch number (indicate how many epochs have you trained this model) |
||||
:type epoch: int |
||||
:param model: model to be registered |
||||
:type model: torch.nn.Module |
||||
:param optimizer: optimizer to be registered |
||||
:type optimizer: torch.optim.Optimizer |
||||
:param lr_scheduler: lr_scheduler to be registered, defaults to None |
||||
:type lr_scheduler: torch.optim.lr_scheduler._LRScheduler, optional |
||||
''' |
||||
# for compatibility with normal pytorch nn.Module |
||||
if hasattr(model, 'state_dict_for_save_checkpoint'): |
||||
model_sd = model.state_dict_for_save_checkpoint() |
||||
else: |
||||
model_sd = model.state_dict() |
||||
|
||||
# ckpt container |
||||
checkpoint = { |
||||
'epoch': epoch, |
||||
'model': model_sd, |
||||
'optimizer': optimizer.state_dict(), |
||||
**kwargs |
||||
} |
||||
if lr_scheduler is not None: |
||||
checkpoint['lr_scheduler'] = lr_scheduler.state_dict() |
||||
|
||||
_ensure_directory_exists(checkpoint_path) |
||||
torch.save(checkpoint, checkpoint_path) |
||||
|
||||
|
||||
def load_checkpoint(checkpoint_path: str, |
||||
model: torch.nn.Module, |
||||
optimizer: torch.optim.Optimizer, |
||||
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None, |
||||
finetune: bool = False, |
||||
strict: bool = True) -> Tuple: |
||||
'''Loads the checkpoint file. |
||||
If finetune is False, the training process is resumed from the given checkpoint:
parameters and buffers are copied from the state_dict into the model, optimizer and lr_scheduler and their descendants.
If finetune is True, only the weights and buffers of the model are reloaded.
If strict is True, the keys of the state_dict must exactly match the keys returned by the model's state_dict() function.
||||
|
||||
:param checkpoint_path: the exact and matched checkpoint_path directory to retrieve appropriate state_dict |
||||
:type checkpoint_path: str |
||||
:param model: model to reload parameters and buffers |
||||
:type model: torch.nn.Module |
||||
:param optimizer: optimizer to recuperate |
||||
:type optimizer: torch.optim.Optimizer |
||||
:param lr_scheduler: lr_scheduler to recuperate, defaults to None |
||||
:type lr_scheduler: torch.optim.lr_scheduler._LRScheduler, optional |
||||
:param finetune: whether to finetune the model with new dataset or continue the pre-training, defaults to False |
||||
:type finetune: bool, optional |
||||
:param strict: whether to strictly enforce that the keys in |
||||
:attr:`state_dict` of the checkpoint match the names of |
||||
parameters and buffers in model., defaults to True |
||||
:type strict: bool, optional |
||||
:raises ValueError: raise error if the model/optimizer cannot successfully be recuperated |
||||
:return: (the epoch number of the checkpoint retrieved, the checkpoint retrieved) |
||||
:rtype: Tuple |
||||
|
||||
''' |
||||
# Load the checkpoint. |
||||
checkpoint = torch.load(checkpoint_path, map_location='cpu') |
||||
try: |
||||
last_epoch = checkpoint.pop('epoch') if not finetune else 0 |
||||
model.load_state_dict(checkpoint.pop('model'), strict=strict) |
||||
except KeyError: |
||||
raise ValueError('Checkpoint is corrupted') |
||||
|
||||
if not finetune: |
||||
try: |
||||
optimizer.load_state_dict(checkpoint.pop('optimizer')) |
||||
except KeyError: |
||||
raise ValueError('Checkpoint is corrupted') |
||||
|
||||
if lr_scheduler is not None and 'lr_scheduler' in checkpoint: |
||||
lr_scheduler.load_state_dict(checkpoint.pop('lr_scheduler')) |
||||
|
||||
return last_epoch, checkpoint |
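# --- Illustrative usage (annotation, not part of this file) ---
# In an initialized colossalai run, checkpoint filenames encode the epoch and the
# tensor/pipeline ranks, e.g. 'epoch10-tp0-pp0.pt':
#
#     path = get_checkpoint_path('./ckpt', epoch=10)
#     save_checkpoint(path, 10, model, optimizer, lr_scheduler)
#     ...
#     latest = get_latest_checkpoint_path('./ckpt')
#     last_epoch, extra_states = load_checkpoint(latest, model, optimizer, lr_scheduler)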
colossalai/communication/__init__.py
@ -0,0 +1,14 @@

from .collective import all_gather, reduce_scatter, scatter
from .p2p import (send_forward, send_forward_recv_forward, send_backward_recv_forward,
                  send_backward, send_backward_recv_backward, send_forward_recv_backward,
                  send_forward_backward_recv_forward_backward, recv_forward, recv_backward)
from .ring import ring_forward
from .utils import send_tensor_meta, recv_tensor_meta

__all__ = [
    'all_gather', 'reduce_scatter', 'scatter',
    'send_forward', 'send_forward_recv_forward', 'send_forward_backward_recv_forward_backward',
    'send_backward', 'send_backward_recv_backward', 'send_backward_recv_forward',
    'send_forward_recv_backward', 'recv_backward', 'recv_forward',
    'ring_forward', 'send_tensor_meta', 'recv_tensor_meta'
]
colossalai/communication/collective.py
@ -0,0 +1,84 @@

#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch
import torch.distributed as dist
from torch import Tensor

from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device


def all_gather(tensor: Tensor, dim: int,
               parallel_mode: ParallelMode) -> Tensor:
    """Gathers all tensors from the parallel group and concatenates them in a
    specific dimension.

    :param tensor: Tensor to be gathered
    :param dim: The dimension to concatenate along
    :param parallel_mode: Parallel group mode used in this communication
    :type tensor: Tensor
    :type dim: int
    :type parallel_mode: ParallelMode
    :return: The tensor generated by all-gather
    :rtype: Tensor
    """
    depth = gpc.get_world_size(parallel_mode)
    temp = tensor.clone()
    shape = list(temp.shape)
    shape[dim] *= depth
    out = torch.empty(shape, dtype=temp.dtype, device=get_current_device())
    out = list(torch.chunk(out, depth, dim=dim))
    out = [val.contiguous() for val in out]
    dist.all_gather(out, temp, group=gpc.get_group(parallel_mode))
    out = torch.cat(out, dim=dim)
    return out


def reduce_scatter(tensor: Tensor, dim: int,
                   parallel_mode: ParallelMode) -> Tensor:
    """Reduces all tensors, then scatters the result in a specific dimension to all
    members of the parallel group.

    :param tensor: Tensor to be reduced and scattered
    :param dim: The dimension to scatter along
    :param parallel_mode: Parallel group mode used in this communication
    :type tensor: Tensor
    :type dim: int
    :type parallel_mode: ParallelMode
    :return: The tensor generated by reduce-scatter
    :rtype: Tensor
    """
    depth = gpc.get_world_size(parallel_mode)
    temp = list(torch.chunk(tensor, depth, dim=dim))
    temp = [val.contiguous() for val in temp]
    out = torch.empty(temp[0].shape,
                      dtype=temp[0].dtype,
                      device=get_current_device())
    dist.reduce_scatter(output=out,
                        input_list=temp,
                        group=gpc.get_group(parallel_mode))
    return out


def scatter(tensor: Tensor, src: int, dim: int,
            parallel_mode: ParallelMode) -> Tensor:
    """Scatters a tensor in a specific dimension from the source rank to all ranks in
    the parallel group.

    :param tensor: Tensor to be scattered
    :param src: The source rank of the scatter
    :param dim: The dimension to scatter along
    :param parallel_mode: Parallel group mode used in this communication
    :type tensor: Tensor
    :type src: int
    :type dim: int
    :type parallel_mode: ParallelMode
    :return: The tensor generated by scatter
    :rtype: Tensor
    """
    depth = gpc.get_world_size(parallel_mode)
    temp = tensor.clone()
    dist.broadcast(temp, src=src, group=gpc.get_group(parallel_mode))
    rank = gpc.get_local_rank(parallel_mode)
    out = torch.chunk(temp, depth, dim=dim)[rank].contiguous()
    return out
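# --- Illustrative usage (annotation, not part of this file); assumes an initialized
# parallel context with a tensor-parallel group (ParallelMode.TENSOR is assumed here) ---
#
#     x = torch.randn(4, 8, device=get_current_device())        # local shard on each rank
#     gathered = all_gather(x, dim=-1, parallel_mode=ParallelMode.TENSOR)
#     # gathered.shape[-1] == 8 * gpc.get_world_size(ParallelMode.TENSOR)
#     shard = scatter(gathered, src=0, dim=-1, parallel_mode=ParallelMode.TENSOR)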
colossalai/communication/p2p.py
@ -0,0 +1,333 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
import torch |
||||
import torch.distributed as dist |
||||
|
||||
from colossalai.context.parallel_mode import ParallelMode |
||||
from colossalai.core import global_context as gpc |
||||
from colossalai.utils import get_current_device |
||||
|
||||
|
||||
def _communicate(tensor_send_next=None, |
||||
tensor_send_prev=None, |
||||
recv_prev=False, |
||||
recv_next=False, |
||||
recv_prev_shape=None, |
||||
recv_next_shape=None, |
||||
prev_rank=None, |
||||
next_rank=None, |
||||
up_group=None, |
||||
down_group=None, |
||||
dtype=None): |
||||
""" |
||||
Adapted from megatron.p2p_communication. |
||||
Communicate tensors between stages. Used as helper method in other |
||||
communication methods that are used in pipeline schedule. |
||||
Takes the following arguments: |
||||
tensor_send_next: tensor to send to next rank (no tensor sent if |
||||
set to None). |
||||
tensor_send_prev: tensor to send to prev rank (no tensor sent if |
||||
set to None). |
||||
recv_prev: boolean for whether tensor should be received from |
||||
previous rank. |
||||
recv_next: boolean for whether tensor should be received from |
||||
next rank. |
||||
Returns: |
||||
(tensor_recv_prev, tensor_recv_next) |
||||
""" |
||||
|
||||
# Create placeholder tensors for receive in forward and backward directions |
||||
# if needed. |
||||
tensor_recv_prev = None |
||||
tensor_recv_next = None |
||||
|
||||
if recv_prev: |
||||
assert recv_prev_shape is not None |
||||
tensor_recv_prev = torch.empty(recv_prev_shape, |
||||
requires_grad=True, |
||||
device=get_current_device(), |
||||
dtype=dtype) |
||||
if recv_next: |
||||
assert recv_next_shape is not None |
||||
tensor_recv_next = torch.empty(recv_next_shape, |
||||
requires_grad=True, |
||||
device=get_current_device(), |
||||
dtype=dtype) |
||||
|
||||
if tensor_send_prev is not None or recv_prev: |
||||
if prev_rank is None: |
||||
prev_rank = gpc.get_prev_global_rank( |
||||
ParallelMode.PIPELINE) |
||||
if up_group is None: |
||||
up_group = gpc.get_group(ParallelMode.PIPELINE_PREV) |
||||
|
||||
if tensor_send_next is not None or recv_next: |
||||
if next_rank is None: |
||||
next_rank = gpc.get_next_global_rank( |
||||
ParallelMode.PIPELINE) |
||||
if down_group is None: |
||||
down_group = gpc.get_group(ParallelMode.PIPELINE_NEXT) |
||||
|
||||
# rank = dist.get_rank() |
||||
rank = gpc.get_global_rank() |
||||
|
||||
ops = [] |
||||
if tensor_send_prev is not None: |
||||
send_prev_op = dist.broadcast(tensor_send_prev, |
||||
src=rank, |
||||
group=up_group, |
||||
async_op=True) |
||||
ops.append(send_prev_op) |
||||
if tensor_recv_prev is not None: |
||||
recv_prev_op = dist.broadcast(tensor_recv_prev, |
||||
src=prev_rank, |
||||
group=up_group, |
||||
async_op=True) |
||||
ops.append(recv_prev_op) |
||||
if tensor_recv_next is not None: |
||||
recv_next_op = dist.broadcast(tensor_recv_next, |
||||
src=next_rank, |
||||
group=down_group, |
||||
async_op=True) |
||||
ops.append(recv_next_op) |
||||
if tensor_send_next is not None: |
||||
send_next_op = dist.broadcast(tensor_send_next, |
||||
src=rank, |
||||
group=down_group, |
||||
async_op=True) |
||||
ops.append(send_next_op) |
||||
for req in ops: |
||||
req.wait() |
||||
# To protect against race condition when using batch_isend_irecv(). |
||||
torch.cuda.synchronize() |
||||
return tensor_recv_prev, tensor_recv_next |
||||
|
||||
|
||||
def recv_forward(input_tensor_shape, prev_rank=None, up_group=None): |
||||
"""Receives the input tensor from the previous member in pipeline. |
||||
|
||||
:param input_tensor_shape: The shape of the tensor to be received
||||
:param prev_rank: The rank of the source of the tensor |
||||
:param up_group: Communication group including the previous member in pipeline parallel group |
||||
:type input_tensor_shape: torch.Size |
||||
:type prev_rank: int, optional |
||||
:type up_group: ProcessGroup, optional |
||||
:return: The input tensor in forward step |
||||
:rtype: Tensor |
||||
""" |
||||
if gpc.is_first_rank(ParallelMode.PIPELINE): |
||||
input_tensor = None |
||||
else: |
||||
input_tensor, _ = _communicate(recv_prev=True, |
||||
recv_prev_shape=input_tensor_shape, |
||||
prev_rank=prev_rank, |
||||
up_group=up_group) |
||||
return input_tensor |
||||
|
||||
|
||||
def recv_backward(output_grad_shape, next_rank=None, down_group=None): |
||||
"""Receives the grad tensor from the next member in pipeline. |
||||
|
||||
:param output_grad_shape: The shape of the tensor to be received
||||
:param next_rank: The rank of the source of the tensor |
||||
:param down_group: Communication group including the next member in pipeline parallel group |
||||
:type output_grad_shape: torch.Size |
||||
:type next_rank: int, optional |
||||
:type down_group: ProcessGroup, optional |
||||
:return: The grad of output tensor in forward step |
||||
:rtype: Tensor |
||||
""" |
||||
if gpc.is_last_rank(ParallelMode.PIPELINE): |
||||
output_tensor_grad = None |
||||
else: |
||||
_, output_tensor_grad = _communicate(recv_next=True, |
||||
recv_next_shape=output_grad_shape, |
||||
next_rank=next_rank, |
||||
down_group=down_group) |
||||
return output_tensor_grad |
||||
|
||||
|
||||
def send_forward(output_tensor, |
||||
next_rank=None, |
||||
down_group=None): |
||||
"""Sends the input tensor to the next member in pipeline. |
||||
|
||||
:param output_tensor: Tensor to be sent |
||||
:param next_rank: The rank of the recipient of the tensor |
||||
:param down_group: Communication group including the next member in pipeline parallel group |
||||
:type output_tensor: Tensor |
||||
:type next_rank: int, optional |
||||
:type down_group: ProcessGroup, optional |
||||
""" |
||||
if not gpc.is_last_rank(ParallelMode.PIPELINE): |
||||
_communicate(tensor_send_next=output_tensor, |
||||
next_rank=next_rank, |
||||
down_group=down_group) |
||||
|
||||
|
||||
def send_backward(input_tensor_grad, |
||||
prev_rank=None, |
||||
up_group=None): |
||||
"""Sends the grad tensor to the previous member in pipeline. |
||||
|
||||
:param input_tensor_grad: Tensor to be sent |
||||
:param prev_rank: The rank of the recipient of the tensor |
||||
:param up_group: Communication group including the previous member in pipeline parallel group |
||||
:type input_tensor_grad: Tensor |
||||
:type prev_rank: int, optional |
||||
:type up_group: ProcessGroup, optional |
||||
""" |
||||
if not gpc.is_first_rank(ParallelMode.PIPELINE): |
||||
_communicate(tensor_send_prev=input_tensor_grad, |
||||
prev_rank=prev_rank, |
||||
up_group=up_group) |
||||
|
||||
|
||||
def send_forward_recv_backward(output_tensor, |
||||
output_grad_shape, |
||||
recv_next=True, |
||||
next_rank=None, |
||||
down_group=None): |
||||
"""Batched communication operation. Sends the input tensor to the |
||||
next member in pipeline, while receiving the grad tensor from the
||||
next member in pipeline. |
||||
|
||||
:param output_tensor: Tensor to be sent |
||||
:param output_grad_shape: The shape of the tensor to be received
||||
:type output_tensor: Tensor |
||||
:type output_grad_shape: torch.Size |
||||
:return: The grad of output tensor in forward step |
||||
:rtype: Tensor |
||||
""" |
||||
if gpc.is_last_rank(ParallelMode.PIPELINE): |
||||
output_tensor_grad = None |
||||
else: |
||||
_, output_tensor_grad = _communicate(tensor_send_next=output_tensor, |
||||
recv_next=recv_next, |
||||
recv_next_shape=output_grad_shape, |
||||
next_rank=next_rank, |
||||
down_group=down_group) |
||||
return output_tensor_grad |
||||
|
||||
|
||||
def send_backward_recv_forward(input_tensor_grad, |
||||
input_tensor_shape, |
||||
recv_prev=True, |
||||
prev_rank=None, |
||||
up_group=None): |
||||
"""Batched communication operation. Sends the grad tensor to the |
||||
previous member in pipeline, while receiving the input tensor from the
||||
previous member in pipeline. |
||||
|
||||
:param input_tensor_grad: Tensor to be sent |
||||
:param input_tensor_shape: The shape of the tensor to be received
||||
:type input_tensor_grad: Tensor |
||||
:type input_tensor_shape: torch.Size |
||||
:return: The input tensor in forward step |
||||
:rtype: Tensor |
||||
""" |
||||
if gpc.is_first_rank(ParallelMode.PIPELINE): |
||||
input_tensor = None |
||||
else: |
||||
input_tensor, _ = _communicate(tensor_send_prev=input_tensor_grad, |
||||
recv_prev=recv_prev, |
||||
recv_prev_shape=input_tensor_shape, |
||||
prev_rank=prev_rank, |
||||
up_group=up_group) |
||||
return input_tensor |
||||
|
||||
|
||||
def send_forward_recv_forward(output_tensor, |
||||
input_tensor_shape, |
||||
recv_prev=True, |
||||
prev_rank=None, |
||||
next_rank=None, |
||||
up_group=None, |
||||
down_group=None): |
||||
"""Batched communication operation. Sends the input tensor to the |
||||
next member in pipeline, while receiving the input tensor from the
||||
previous member in pipeline. |
||||
|
||||
:param output_tensor: Tensor to be sent |
||||
:param input_tensor_shape: The shape of the tensor to be received
||||
:type output_tensor: Tensor |
||||
:type input_tensor_shape: torch.Size |
||||
:return: The input tensor in forward step |
||||
:rtype: Tensor |
||||
""" |
||||
input_tensor, _ = _communicate(tensor_send_next=output_tensor, |
||||
recv_prev=recv_prev, |
||||
recv_prev_shape=input_tensor_shape, |
||||
prev_rank=prev_rank, |
||||
next_rank=next_rank, |
||||
up_group=up_group, |
||||
down_group=down_group) |
||||
return input_tensor |
||||
|
||||
|
||||
def send_backward_recv_backward(input_tensor_grad, |
||||
output_grad_shape, |
||||
recv_next=True, |
||||
prev_rank=None, |
||||
next_rank=None, |
||||
up_group=None, |
||||
down_group=None): |
||||
"""Batched communication operation. Sends the grad tensor to the |
||||
previous member in pipeline, while receiving the grad tensor from the
||||
next member in pipeline. |
||||
|
||||
:param input_tensor_grad: Tensor to be sent |
||||
:param output_grad_shape: The shape of the tensor to be received
||||
:type input_tensor_grad: Tensor |
||||
:type output_grad_shape: torch.Size |
||||
:return: The grad of output tensor in forward step |
||||
:rtype: Tensor |
||||
""" |
||||
_, output_tensor_grad = _communicate(tensor_send_prev=input_tensor_grad, |
||||
recv_next=recv_next, |
||||
recv_next_shape=output_grad_shape, |
||||
prev_rank=prev_rank, |
||||
next_rank=next_rank, |
||||
up_group=up_group, |
||||
down_group=down_group) |
||||
return output_tensor_grad |
||||
|
||||
|
||||
def send_forward_backward_recv_forward_backward(output_tensor, |
||||
input_tensor_grad, |
||||
input_tensor_shape, |
||||
output_grad_shape, |
||||
recv_prev=True, |
||||
recv_next=True, |
||||
prev_rank=None, |
||||
next_rank=None, |
||||
up_group=None, |
||||
down_group=None): |
||||
"""Batched communication operation. Sends the input tensor to the next and |
||||
the grad tensor to the previous, while receiving the grad tensor from the
||||
next and the input tensor from the previous. |
||||
|
||||
:param output_tensor: Tensor sent to the next |
||||
:param input_tensor_grad: Tensor sent to the previous |
||||
:param input_tensor_shape: The shape of the tensor received from the previous
:param output_grad_shape: The shape of the tensor received from the next
||||
:type output_tensor: Tensor |
||||
:type input_tensor_grad: Tensor |
||||
:type input_tensor_shape: torch.Size |
||||
:type output_grad_shape: torch.Size |
||||
:return: (the input tensor in forward step, the grad of output tensor in forward step) |
||||
:rtype: (Tensor, Tensor) |
||||
""" |
||||
input_tensor, output_tensor_grad = _communicate( |
||||
tensor_send_next=output_tensor, |
||||
tensor_send_prev=input_tensor_grad, |
||||
recv_prev=recv_prev, |
||||
recv_next=recv_next, |
||||
recv_prev_shape=input_tensor_shape, |
||||
recv_next_shape=output_grad_shape, |
||||
prev_rank=prev_rank, |
||||
next_rank=next_rank, |
||||
up_group=up_group, |
||||
down_group=down_group) |
||||
return input_tensor, output_tensor_grad |
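# --- Illustrative usage (annotation, not part of this file) ---
# For an intermediate pipeline stage, one steady-state step of a schedule looks roughly like:
#
#     input_tensor = recv_forward(input_tensor_shape)                       # from previous stage
#     output_tensor = forward_step(input_tensor)                            # user-defined forward
#     output_grad = send_forward_recv_backward(output_tensor, output_grad_shape)
#     output_tensor.backward(output_grad)
#     send_backward(input_tensor.grad)                                      # to previous stage
#
# The first and last stages skip the receive/send on the side that has no neighbour,
# which is what the gpc.is_first_rank/is_last_rank checks above implement.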
colossalai/communication/ring.py
@ -0,0 +1,54 @@

#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch

from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device, synchronize


def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode):
    """Sends a tensor to the next member and receives a tensor from the previous member.
    This function returns the received tensor from the previous member.

    :param tensor_send_next: Tensor sent to the next member
    :param parallel_mode: Parallel group mode used in this communication
    :type tensor_send_next: Tensor
    :type parallel_mode: ParallelMode
    :return: The tensor received from the previous member
    :rtype: Tensor
    """
    buffer_shape = tensor_send_next.size()

    ops = []
    current_rank = gpc.get_global_rank()

    tensor_recv_prev = torch.empty(buffer_shape,
                                   requires_grad=True,
                                   device=get_current_device(),
                                   dtype=tensor_send_next.dtype)

    # send to next rank
    send_next_op = torch.distributed.P2POp(
        torch.distributed.isend, tensor_send_next,
        gpc.get_next_global_rank(parallel_mode))
    ops.append(send_next_op)

    # receive from prev rank
    recv_prev_op = torch.distributed.P2POp(
        torch.distributed.irecv, tensor_recv_prev,
        gpc.get_prev_global_rank(parallel_mode))
    ops.append(recv_prev_op)

    if current_rank % 2 == 0:
        ops = ops[::-1]

    reqs = torch.distributed.batch_isend_irecv(ops)
    for req in reqs:
        req.wait()

    # To protect against race condition when using batch_isend_irecv().
    synchronize()

    return tensor_recv_prev
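# --- Illustrative usage (annotation, not part of this file) ---
# Passing a block once around the ring per step lets every rank see every other rank's
# block after world_size - 1 steps (the pattern used by ring-based sequence parallelism;
# ParallelMode.SEQUENCE is assumed to be the relevant group here):
#
#     block = local_block                              # same shape/dtype on every rank
#     for _ in range(gpc.get_world_size(ParallelMode.SEQUENCE) - 1):
#         block = ring_forward(block, ParallelMode.SEQUENCE)
#         ...                                          # accumulate partial results with block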
colossalai/communication/utils.py
@ -0,0 +1,73 @@

import torch
import torch.distributed as dist

from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device


def send_tensor_meta(tensor, need_meta=True, down_group=None):
    """Sends tensor meta information before sending a specific tensor.
    Since the recipient must know the shape of the tensor in p2p communications,
    meta information of the tensor should be sent before communications. This function
    synchronizes with :func:`recv_tensor_meta`.

    :param tensor: Tensor to be sent
    :param need_meta: If False, meta information won't be sent
    :param down_group: Communication group including the next member in the pipeline parallel group
    :type tensor: Tensor
    :type need_meta: bool, optional
    :type down_group: ProcessGroup, optional
    :return: False
    :rtype: bool
    """
    if need_meta:
        rank = gpc.get_global_rank()

        if down_group is None:
            down_group = gpc.get_group(ParallelMode.PIPELINE_NEXT)

        tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()}

        send_shape = torch.tensor(tensor.size(), **tensor_kwargs)
        send_ndims = torch.tensor(len(tensor.size()), **tensor_kwargs)

        dist.broadcast(send_ndims, src=rank, group=down_group)
        dist.broadcast(send_shape, src=rank, group=down_group)

    return False


def recv_tensor_meta(tensor_shape, prev_rank=None, up_group=None):
    """Receives tensor meta information before receiving a specific tensor.
    Since the recipient must know the shape of the tensor in p2p communications,
    meta information of the tensor should be received before communications. This function
    synchronizes with :func:`send_tensor_meta`.

    :param tensor_shape: The shape of the tensor to be received
    :param prev_rank: The rank of the source of the tensor
    :param up_group: Communication group including the previous member in the pipeline parallel group
    :type tensor_shape: torch.Size
    :type prev_rank: int, optional
    :type up_group: ProcessGroup, optional
    :return: The shape of the tensor to be received
    :rtype: torch.Size
    """
    if tensor_shape is None:
        if prev_rank is None:
            prev_rank = gpc.get_prev_global_rank(
                ParallelMode.PIPELINE)
        if up_group is None:
            up_group = gpc.get_group(ParallelMode.PIPELINE_PREV)

        tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()}

        recv_ndims = torch.empty((), **tensor_kwargs)
        dist.broadcast(recv_ndims, src=prev_rank, group=up_group)

        recv_shape = torch.empty(recv_ndims, **tensor_kwargs)
        dist.broadcast(recv_shape, src=prev_rank, group=up_group)

        tensor_shape = torch.Size(recv_shape)

    return tensor_shape
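# --- Illustrative usage (annotation, not part of this file) ---
# The sender publishes the shape once so the receiver can allocate its buffer before
# the actual p2p transfer (send_forward/recv_forward come from colossalai.communication.p2p):
#
#     # sending stage
#     need_meta = send_tensor_meta(output_tensor, need_meta)   # returns False afterwards
#     send_forward(output_tensor)
#
#     # receiving stage
#     shape = recv_tensor_meta(tensor_shape=None)
#     input_tensor = recv_forward(shape)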
colossalai/constants.py
@ -0,0 +1,31 @@

#!/usr/bin/env python
# -*- encoding: utf-8 -*-

ALLOWED_MODES = [None, '1d', '2d', '2.5d', '3d', 'sequence']

# initializer
INITIALIZER_MAPPING = {
    'data': 'Initializer_Data',
    'tensor': 'Initializer_Tensor',
    'pipeline': 'Initializer_Pipeline',
    'embedding': 'Initializer_Embedding',
    '1d': 'Initializer_1D',
    '2d': 'Initializer_2D',
    '2.5d': 'Initializer_2p5D',
    '3d': 'Initializer_3D',
    'sequence': 'Initializer_Sequence'
}

# 2D parallel
SUMMA_DIM = 'SUMMA_DIM'

# 2.5D parallel
TESSERACT_DIM = 'TESSERACT_DIM'
TESSERACT_DEP = 'TESSERACT_DEP'

# 3D parallel
DEPTH_3D = 'DEPTH_3D'

# Tensor parallel attributes
IS_TENSOR_PARALLEL = 'is_tensor_parallel'
TENSOR_PARALLEL_ATTRIBUTES = [IS_TENSOR_PARALLEL]
@ -0,0 +1,5 @@
|
||||
from .config import Config |
||||
from .parallel_context import ParallelContext |
||||
from .parallel_context import ParallelMode |
||||
from .process_group_initializer import * |
||||
from .random import * |
@ -0,0 +1,70 @@
|
||||
import math |
||||
|
||||
|
||||
def set_parallel_size(obj, config: dict, key: str, attr_name: str): |
||||
if key in config: |
||||
ele = config[key] |
||||
if isinstance(ele, int): |
||||
setattr(obj, attr_name, ele) |
||||
elif isinstance(ele, dict): |
||||
setattr(obj, attr_name, ele['size']) |
||||
else: |
||||
raise NotImplementedError( |
||||
f"Parallel configuration does not support this kind of argument, please use int or dict" |
||||
) |
||||
|
||||
|
||||
def add_tensor_pg(pg_init, mode, size, depth=None): |
||||
if mode == '1d': |
||||
pg_init.append(dict( |
||||
type='Initializer1D', |
||||
parallel_size=size |
||||
)) |
||||
elif mode == '2d': |
||||
dim = math.floor(math.sqrt(size)) |
||||
pg_init.append(dict( |
||||
type='Initializer2D_Col', |
||||
summa_dim=dim |
||||
)) |
||||
pg_init.append(dict( |
||||
type='Initializer2D_Row', |
||||
summa_dim=dim |
||||
)) |
||||
elif mode == '2.5d': |
||||
dim = math.floor(math.sqrt(size // depth)) |
||||
pg_init.append(dict( |
||||
type='Initializer_Tesseract_ROW', |
||||
tesseract_dim=dim, |
||||
tesseract_dep=depth |
||||
)) |
||||
pg_init.append(dict( |
||||
type='Initializer_Tesseract_COL', |
||||
tesseract_dim=dim, |
||||
tesseract_dep=depth |
||||
)) |
||||
pg_init.append(dict( |
||||
type='Initializer_Tesseract_DEP', |
||||
tesseract_dim=dim, |
||||
tesseract_dep=depth |
||||
)) |
||||
pg_init.append(dict( |
||||
type='Initializer_Tesseract_XZ', |
||||
tesseract_dim=dim, |
||||
tesseract_dep=depth |
||||
)) |
||||
elif mode == '3d': |
||||
dim = math.floor(math.pow(size, 1.0 / 3.0) + 0.5) |
||||
pg_init.append(dict( |
||||
type='ParallelInitializer3D_Input', |
||||
depth=dim |
||||
)) |
||||
pg_init.append(dict( |
||||
type='ParallelInitializer3D_Weight', |
||||
depth=dim |
||||
)) |
||||
pg_init.append(dict( |
||||
type='ParallelInitializer3D_Output', |
||||
depth=dim |
||||
)) |
||||
else: |
||||
raise NotImplementedError("This kind of tensor splitting has not been implemented yet") |
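# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): set_parallel_size accepts
# either a plain int or a dict carrying a 'size' key, so both configuration
# styles below resolve to the same attribute. The _Ctx class is a stand-in used
# only for this example.

if __name__ == '__main__':
    class _Ctx:
        pass

    ctx = _Ctx()
    set_parallel_size(ctx, {'tensor': 4}, 'tensor', 'tensor_parallel_size')
    assert ctx.tensor_parallel_size == 4

    set_parallel_size(ctx, {'tensor': {'size': 4, 'mode': '2d'}}, 'tensor', 'tensor_parallel_size')
    assert ctx.tensor_parallel_size == 4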
@ -0,0 +1,99 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
import inspect |
||||
import sys |
||||
from importlib.machinery import SourceFileLoader |
||||
from pathlib import Path |
||||
|
||||
|
||||
class Config(dict): |
||||
"""This is a wrapper class for dict objects so that values of which can be |
||||
accessed as attributes. |
||||
|
||||
:param config: The dict object to be wrapped |
||||
:type config: dict |
||||
""" |
||||
|
||||
def __init__(self, config: dict = None): |
||||
if config is not None: |
||||
for k, v in config.items(): |
||||
self._add_item(k, v) |
||||
|
||||
def __missing__(self, key): |
||||
raise KeyError(key) |
||||
|
||||
def __getattr__(self, key): |
||||
try: |
||||
value = super(Config, self).__getitem__(key) |
||||
return value |
||||
except KeyError: |
||||
raise AttributeError(key) |
||||
|
||||
def __setattr__(self, key, value): |
||||
super(Config, self).__setitem__(key, value) |
||||
|
||||
def _add_item(self, key, value): |
||||
if isinstance(value, dict): |
||||
self.__setattr__(key, Config(value)) |
||||
else: |
||||
self.__setattr__(key, value) |
||||
|
||||
def update(self, config): |
||||
assert isinstance(config, (Config, dict)), 'can only update dictionary or Config objects.' |
||||
for k, v in config.items(): |
||||
self._add_item(k, v) |
||||
return self |
||||
|
||||
@staticmethod |
||||
def from_file(filename: str): |
||||
"""Reads a python file and constructs a corresponding :class:`Config` object. |
||||
|
||||
:param filename: Name of the file to construct the return object |
||||
:type filename: str |
||||
:raises AssertionError: Raises an AssertionError if the file does not exist, or the file |
||||
is not a .py file |
||||
:return: A :class:`Config` object constructed with information in the file |
||||
:rtype: :class:`Config` |
||||
""" |
||||
|
||||
# check config path |
||||
if isinstance(filename, str): |
||||
filepath = Path(filename).absolute() |
||||
elif isinstance(filename, Path): |
||||
filepath = filename.absolute() |
||||
|
||||
assert filepath.exists(), f'{filename} is not found, please check your configuration path' |
||||
|
||||
# check extension |
||||
extension = filepath.suffix |
||||
assert extension == '.py', 'only .py files are supported' |
||||
|
||||
# import the config as module |
||||
remove_path = False |
||||
if str(filepath.parent) not in sys.path: |
||||
sys.path.insert(0, str(filepath.parent)) |
||||
remove_path = True |
||||
|
||||
module_name = filepath.stem |
||||
source_file = SourceFileLoader(fullname=str(module_name), path=str(filepath)) |
||||
module = source_file.load_module() |
||||
|
||||
# load into config |
||||
config = Config() |
||||
|
||||
for k, v in module.__dict__.items(): |
||||
if k.startswith('__') or inspect.ismodule(v) or inspect.isclass(v): |
||||
continue |
||||
else: |
||||
config._add_item(k, v) |
||||
|
||||
# TODO: replace with logger warning here when logger is done |
||||
print('warning: variables that start with __, or are modules or classes, are omitted from the config') |
||||
|
||||
# remove module |
||||
del sys.modules[module_name] |
||||
if remove_path: |
||||
sys.path.pop(0) |
||||
|
||||
return config |
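# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): nested dicts are wrapped
# recursively, so configuration values can be read as attributes. For
# Config.from_file, point it at any .py file whose module-level variables make
# up the configuration.

if __name__ == '__main__':
    cfg = Config({'parallel': {'pipeline': 2, 'tensor': {'size': 4, 'mode': '2d'}}})
    assert cfg.parallel.tensor.size == 4   # nested dicts become Config objects
    cfg.update({'seed': 1024})             # update() wraps nested values the same way
    assert cfg.seed == 1024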
@ -0,0 +1,454 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
import os |
||||
import random |
||||
from typing import Union |
||||
|
||||
import numpy as np |
||||
import torch |
||||
import torch.distributed as dist |
||||
|
||||
from colossalai.constants import ALLOWED_MODES, INITIALIZER_MAPPING |
||||
from colossalai.context.config import Config |
||||
from colossalai.registry import DIST_GROUP_INITIALIZER |
||||
from ._utils import set_parallel_size |
||||
from .parallel_mode import ParallelMode |
||||
from .random import add_seed, get_seeds, set_mode |
||||
|
||||
|
||||
class ParallelContext: |
||||
"""This class provides interface functions for users to get the parallel context, |
||||
such as the global rank, the local rank, the world size, etc. of each device. |
||||
|
||||
:param args: The distributed arguments in the system |
||||
:type args: dict |
||||
""" |
||||
|
||||
def __init__(self, args=None): |
||||
# distributed settings |
||||
self._global_ranks = dict() |
||||
self._local_ranks = dict() |
||||
self._world_sizes = dict() |
||||
self._groups = dict() |
||||
self._ranks_in_group = dict() |
||||
|
||||
# load config from file |
||||
self._dist_args = args |
||||
self._config = None |
||||
|
||||
# default 3D parallel args, will be overwritten during process group initialization |
||||
self.world_size = 1 |
||||
self.data_parallel_size = 1 |
||||
self.pipeline_parallel_size = 1 |
||||
self.tensor_parallel_size = 1 |
||||
|
||||
@property |
||||
def config(self): |
||||
return self._config |
||||
|
||||
def load_config(self, config: Union[dict, str]): |
||||
"""Loads the configuration from either a dict or a file. |
||||
|
||||
:param config: Either a dict containing the configuration information or the filename |
||||
of a file containing the configuration information |
||||
:type config: dict or str |
||||
:raises TypeError: Raises a TypeError if `config` is neither a dict nor a str |
||||
""" |
||||
if isinstance(config, str): |
||||
self._config = Config.from_file(config) |
||||
elif isinstance(config, dict): |
||||
self._config = Config(config) |
||||
else: |
||||
raise TypeError("Invalid type for config, only dictionary or string is supported") |
||||
|
||||
def set_dist_args(self, args): |
||||
"""Sets the distributed arguments. |
||||
|
||||
:param args: The distributed arguments in the system |
||||
:type args: dict |
||||
""" |
||||
self._dist_args = args |
||||
|
||||
@staticmethod |
||||
def _check_parallel_mode(parallel_mode: ParallelMode): |
||||
assert isinstance(parallel_mode, ParallelMode) |
||||
|
||||
def get_global_rank(self): |
||||
"""Returns the global rank of the current device. |
||||
|
||||
:return: The global rank of the current device |
||||
:rtype: int |
||||
""" |
||||
return self._global_ranks[ParallelMode.GLOBAL] |
||||
|
||||
def add_global_rank(self, parallel_mode: ParallelMode, rank: int): |
||||
"""Adds the global rank of the current device for `parallel_mode` to the context. |
||||
|
||||
:param parallel_mode: The parallel mode for the rank |
||||
:type parallel_mode: :class:`colossalai.context.ParallelMode` |
||||
:param rank: The rank to be added |
||||
:type rank: int |
||||
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance |
||||
of :class:`colossalai.context.ParallelMode` |
||||
""" |
||||
self._check_parallel_mode(parallel_mode) |
||||
self._global_ranks[parallel_mode] = rank |
||||
|
||||
def get_local_rank(self, parallel_mode: ParallelMode): |
||||
"""Returns the local rank of the current device. |
||||
|
||||
:param parallel_mode: The chosen parallel mode |
||||
:type parallel_mode: :class:`colossalai.context.ParallelMode` |
||||
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance |
||||
of :class:`colossalai.context.ParallelMode` |
||||
:return: The local rank of the current device for `parallel_mode` |
||||
:rtype: int |
||||
""" |
||||
self._check_parallel_mode(parallel_mode) |
||||
return self._local_ranks[parallel_mode] |
||||
|
||||
def add_local_rank(self, parallel_mode: ParallelMode, rank: int): |
||||
"""Adds the local rank of the current device for `parallel_mode` to the context. |
||||
|
||||
:param parallel_mode: The parallel mode for the rank |
||||
:type parallel_mode: :class:`colossalai.context.ParallelMode` |
||||
:param rank: The rank to be added |
||||
:type rank: int |
||||
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance |
||||
of :class:`colossalai.context.ParallelMode` |
||||
""" |
||||
self._check_parallel_mode(parallel_mode) |
||||
self._local_ranks[parallel_mode] = rank |
||||
|
||||
def get_next_global_rank(self, parallel_mode: ParallelMode): |
||||
"""Returns the global rank of the next device. |
||||
|
||||
:param parallel_mode: The chosen parallel mode |
||||
:type parallel_mode: :class:`colossalai.context.ParallelMode` |
||||
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance |
||||
of :class:`colossalai.context.ParallelMode` |
||||
:return: The global rank of the next device for `parallel_mode` |
||||
:rtype: int |
||||
""" |
||||
self._check_parallel_mode(parallel_mode) |
||||
|
||||
# get rank and world size |
||||
local_rank = self.get_local_rank(parallel_mode) |
||||
world_size = self.get_world_size(parallel_mode) |
||||
ranks_in_group = self.get_ranks_in_group(parallel_mode) |
||||
|
||||
return ranks_in_group[(local_rank + 1) % world_size] |
||||
|
||||
def get_prev_global_rank(self, parallel_mode: ParallelMode): |
||||
"""Returns the global rank of the previous device. |
||||
|
||||
:param parallel_mode: The chosen parallel mode |
||||
:type parallel_mode: :class:`colossalai.context.ParallelMode` |
||||
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance |
||||
of :class:`colossalai.context.ParallelMode` |
||||
:return: The global rank of the previous device for `parallel_mode` |
||||
:rtype: int |
||||
""" |
||||
self._check_parallel_mode(parallel_mode) |
||||
|
||||
# get rank and world size |
||||
local_rank = self.get_local_rank(parallel_mode) |
||||
world_size = self.get_world_size(parallel_mode) |
||||
ranks_in_group = self.get_ranks_in_group(parallel_mode) |
||||
|
||||
return ranks_in_group[(local_rank - 1) % world_size] |
||||
|
||||
def is_first_rank(self, parallel_mode: ParallelMode): |
||||
"""Returns a boolean value indicating whether the current device is the first one |
||||
among its group for `parallel_mode`. |
||||
|
||||
:param parallel_mode: The chosen parallel mode |
||||
:type parallel_mode: :class:`colossalai.context.ParallelMode` |
||||
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance |
||||
of :class:`colossalai.context.ParallelMode` |
||||
:return: a boolean value indicating whether the current device is the first one |
||||
among its group for `parallel_mode` |
||||
:rtype: bool |
||||
""" |
||||
rank = self.get_local_rank(parallel_mode) |
||||
return rank == 0 |
||||
|
||||
def is_last_rank(self, parallel_mode: ParallelMode): |
||||
"""Returns a boolean value indicating whether the current device is the last one |
||||
among its group for `parallel_mode`. |
||||
|
||||
:param parallel_mode: The chosen parallel mode |
||||
:type parallel_mode: :class:`colossalai.context.ParallelMode` |
||||
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance |
||||
of :class:`colossalai.context.ParallelMode` |
||||
:return: a boolean value indicating whether the current device is the last one |
||||
among its group for `parallel_mode` |
||||
:rtype: bool |
||||
""" |
||||
rank = self.get_local_rank(parallel_mode) |
||||
world_size = self.get_world_size(parallel_mode) |
||||
return rank == world_size - 1 |
||||
|
||||
def get_world_size(self, parallel_mode: ParallelMode): |
||||
"""Returns the world size for `parallel_mode`. |
||||
|
||||
:param parallel_mode: The chosen parallel mode |
||||
:type parallel_mode: :class:`colossalai.context.ParallelMode` |
||||
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance |
||||
of :class:`colossalai.context.ParallelMode` |
||||
:return: The world size for `parallel_mode` |
||||
:rtype: int |
||||
""" |
||||
self._check_parallel_mode(parallel_mode) |
||||
return self._world_sizes[parallel_mode] |
||||
|
||||
def add_world_size(self, parallel_mode: ParallelMode, world_size: int): |
||||
"""Adds world size for `parallel_mode`. |
||||
|
||||
:param parallel_mode: The chosen parallel mode |
||||
:type parallel_mode: :class:`colossalai.context.ParallelMode` |
||||
:param world_size: The world size to be added |
||||
:type world_size: int |
||||
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance |
||||
of :class:`colossalai.context.ParallelMode` |
||||
""" |
||||
self._check_parallel_mode(parallel_mode) |
||||
self._world_sizes[parallel_mode] = world_size |
||||
|
||||
def get_group(self, parallel_mode: ParallelMode): |
||||
"""Returns the group of the current device for `parallel_mode`. |
||||
|
||||
:param parallel_mode: The chosen parallel mode |
||||
:type parallel_mode: :class:`colossalai.context.ParallelMode` |
||||
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance |
||||
of :class:`colossalai.context.ParallelMode` |
||||
:return: The group of the current device for `parallel_mode` |
||||
:rtype: torch.distributed.ProcessGroup |
||||
""" |
||||
self._check_parallel_mode(parallel_mode) |
||||
return self._groups[parallel_mode] |
||||
|
||||
def add_group(self, parallel_mode: ParallelMode, group: dist.ProcessGroup): |
||||
"""Adds the group of the current device for `parallel_mode`. |
||||
|
||||
:param parallel_mode: The chosen parallel mode |
||||
:type parallel_mode: :class:`colossalai.context.ParallelMode` |
||||
:param group: The group to be added |
||||
:type group: torch.distributed.ProcessGroup |
||||
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance |
||||
of :class:`colossalai.context.ParallelMode` |
||||
""" |
||||
self._check_parallel_mode(parallel_mode) |
||||
self._groups[parallel_mode] = group |
||||
|
||||
def get_ranks_in_group(self, parallel_mode: ParallelMode): |
||||
"""Returns the rank of the current device for `parallel_mode` in the group. |
||||
|
||||
:param parallel_mode: The chosen parallel mode |
||||
:type parallel_mode: :class:`colossalai.context.ParallelMode` |
||||
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance |
||||
of :class:`colossalai.context.ParallelMode` |
||||
:return: The ranks in the current device's group for `parallel_mode` |
||||
:rtype: list |
||||
""" |
||||
self._check_parallel_mode(parallel_mode) |
||||
return self._ranks_in_group[parallel_mode] |
||||
|
||||
def add_ranks_in_group(self, parallel_mode: ParallelMode, ranks: list): |
||||
"""Adds the ranks of the current device for `parallel_mode` in the group. |
||||
|
||||
:param parallel_mode: The chosen parallel mode |
||||
:type parallel_mode: :class:`colossalai.context.ParallelMode` |
||||
:param ranks: List of ranks to be added |
||||
:type ranks: list |
||||
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance |
||||
of :class:`colossalai.context.ParallelMode` |
||||
""" |
||||
self._check_parallel_mode(parallel_mode) |
||||
self._ranks_in_group[parallel_mode] = ranks |
||||
|
||||
def init_global_dist(self, addr=None, port=None): |
||||
"""Initializes the global distributed environment. |
||||
|
||||
:param addr: The IP address of the current device |
||||
:type addr: str, optional |
||||
:param port: The port to be used in the system of the current device |
||||
:type port: int, optional |
||||
""" |
||||
# get config |
||||
rank = self._dist_args.local_rank |
||||
world_size = self._dist_args.world_size |
||||
# default env config, can be overridden by exporting |
||||
# them in your bash script |
||||
addr = os.getenv('MASTER_ADDR', 'localhost') if addr is None else addr |
||||
port = os.getenv('MASTER_PORT', '8008') if port is None else port |
||||
init_method = f'tcp://{addr}:{port}' |
||||
|
||||
dist.init_process_group(backend=self._dist_args.backend, |
||||
rank=rank, |
||||
world_size=world_size, |
||||
init_method=init_method) |
||||
|
||||
# None will give the default global process group for pytorch dist operations |
||||
self._register_dist(rank, world_size, None, |
||||
list(range(world_size)), ParallelMode.GLOBAL) |
||||
self._global_ranks[ParallelMode.GLOBAL] = rank |
||||
|
||||
def _register_dist(self, local_rank, world_size, |
||||
process_group, ranks_in_group, mode): |
||||
self.add_local_rank(mode, local_rank) |
||||
self.add_world_size(mode, world_size) |
||||
self.add_group(mode, process_group) |
||||
self.add_ranks_in_group(mode, ranks_in_group) |
||||
|
||||
def check_sanity(self): |
||||
"""Checks sanity of the parallel context. |
||||
|
||||
:raises AssertionError: Raises an AssertionError if the world size does not equal to the product |
||||
of data parallel size, pipeline parallel size and tensor parallel size |
||||
""" |
||||
dps = self.data_parallel_size |
||||
pps = self.pipeline_parallel_size |
||||
tps = self.tensor_parallel_size |
||||
ws = self.world_size |
||||
assert ws == dps * pps * tps, f"Expected the world size {ws} to be equal to data parallel size ({dps}) * pipeline parallel size ({pps}) * tensor parallel size ({tps})" |
||||
|
||||
def init_parallel_groups(self): |
||||
"""Initializes the parallel groups. |
||||
|
||||
:raises AssertionError: Raises an AssertionError if the field parallel is not present in the config file |
||||
""" |
||||
|
||||
# get rank and world size |
||||
rank = self.get_global_rank() |
||||
world_size = self.get_world_size(ParallelMode.GLOBAL) |
||||
self.world_size = world_size |
||||
|
||||
assert hasattr(self.config, 'parallel'), 'Expected the field parallel to be present in the config file' |
||||
|
||||
# set parallel size as attributes for global context |
||||
parallel_config = self.config.parallel |
||||
set_parallel_size(self, parallel_config, 'pipeline', |
||||
'pipeline_parallel_size') |
||||
set_parallel_size(self, parallel_config, 'tensor', |
||||
'tensor_parallel_size') |
||||
|
||||
# the user should not set the data parallel size manually |
||||
# instead, it should be calculated based on other parallel config |
||||
self.data_parallel_size = self.world_size // (self.pipeline_parallel_size * self.tensor_parallel_size) |
||||
|
||||
# get the tensor parallel mode and check |
||||
tensor_parallel_mode = parallel_config['tensor'].get('mode', None) |
||||
assert tensor_parallel_mode in ALLOWED_MODES, f"mode in the parallel config must be set to one of {ALLOWED_MODES}" |
||||
self.check_sanity() |
||||
|
||||
pg_init = [] |
||||
# LSG: init data parallel process group for compatibility with other parallel module such as zero |
||||
pg_init.append(dict(type=INITIALIZER_MAPPING['data'])) |
||||
|
||||
if self.pipeline_parallel_size > 1: |
||||
pg_init.append(dict(type=INITIALIZER_MAPPING['pipeline'])) |
||||
pg_init.append(dict(type=INITIALIZER_MAPPING['tensor'])) |
||||
|
||||
# init specific tensor parallel group |
||||
if tensor_parallel_mode is not None: |
||||
tensor_parallel_cfg = parallel_config['tensor'].copy() |
||||
|
||||
# remove duplicate parameters |
||||
tensor_parallel_cfg.pop('mode') |
||||
tensor_parallel_cfg.pop('size') |
||||
|
||||
# add this config to initialize later |
||||
pg_init.append(dict(type=INITIALIZER_MAPPING[tensor_parallel_mode.lower()], **tensor_parallel_cfg)) |
||||
|
||||
# run initialization of different process groups |
||||
for initializer_cfg in pg_init: |
||||
cfg = initializer_cfg.copy() |
||||
initializer_type = cfg.pop('type') |
||||
initializer = DIST_GROUP_INITIALIZER.get_module(initializer_type)( |
||||
rank, world_size, self.config, |
||||
self.data_parallel_size, |
||||
self.pipeline_parallel_size, |
||||
self.tensor_parallel_size, |
||||
**cfg) |
||||
parallel_setting = initializer.init_dist_group() |
||||
if isinstance(parallel_setting, list): |
||||
for args in parallel_setting: |
||||
self._register_dist(*args) |
||||
else: |
||||
self._register_dist(*parallel_setting) |
||||
|
||||
def is_initialized(self, parallel_mode: ParallelMode): |
||||
"""Returns a boolean value indicating whether `parallel_mode` is initialized |
||||
in the current system. |
||||
|
||||
:param parallel_mode: The chosen parallel mode |
||||
:type parallel_mode: :class:`colossalai.context.ParallelMode` |
||||
:return: a boolean value indicating whether `parallel_mode` is initialized |
||||
in the current system |
||||
:rtype: bool |
||||
""" |
||||
return parallel_mode in self._groups |
||||
|
||||
def destroy(self): |
||||
"""Destroys the current distributed parallel environment. |
||||
""" |
||||
for mode, group in self._groups.items(): |
||||
if mode is not ParallelMode.GLOBAL: |
||||
dist.destroy_process_group(group) |
||||
# destroy global process group |
||||
dist.destroy_process_group() |
||||
|
||||
def set_device(self): |
||||
"""Sets distributed processes to be bound to devices. |
||||
""" |
||||
devices_per_node = torch.cuda.device_count() |
||||
global_rank = self.get_global_rank() |
||||
device = global_rank % devices_per_node |
||||
torch.cuda.set_device(device) |
||||
print(f'process rank {global_rank} is bound to device {device}') |
||||
|
||||
def set_seed(self): |
||||
"""Sets seeds for all random libraries. |
||||
""" |
||||
if hasattr(self.config, 'seed'): |
||||
seed = getattr(self.config, 'seed') |
||||
else: |
||||
seed = 2 # default seed |
||||
|
||||
random.seed(seed) |
||||
np.random.seed(seed) |
||||
torch.manual_seed(seed) |
||||
|
||||
global_rank = self.get_global_rank() |
||||
|
||||
if torch.cuda.is_available(): |
||||
# create random seed for different parallel modes |
||||
# data parallel seed are kept the same |
||||
parallel_seed = seed |
||||
add_seed(ParallelMode.DATA, parallel_seed) |
||||
|
||||
# model parallel seeds are different across ranks |
||||
pipeline_offset = self._local_ranks.get(ParallelMode.PIPELINE, 0) |
||||
|
||||
# add seed for data parallel and tensor parallel only |
||||
if self.is_initialized(ParallelMode.TENSOR): |
||||
tp_rank = self.get_local_rank(ParallelMode.TENSOR) |
||||
# the offset of 1024 is only to increase the difference in seeds between pipeline stages |
||||
tp_rank_with_offset = tp_rank + pipeline_offset * 1024 |
||||
tp_seed = seed + tp_rank_with_offset |
||||
add_seed(ParallelMode.TENSOR, tp_seed) |
||||
|
||||
set_mode(ParallelMode.DATA) |
||||
seeds = get_seeds() |
||||
seed_str = ', '.join([f'{k}: {v}' for k, v in seeds.items()]) |
||||
|
||||
print(f"initialized seed on rank {global_rank}, " |
||||
f"numpy: {seed}, python random: {seed}, {seed_str}," |
||||
f"the default parallel seed is {ParallelMode.DATA}.", flush=True) |
||||
else: |
||||
print(f"initialized seed on rank {global_rank}, " |
||||
f"numpy: {seed}, python random: {seed}, pytorch: {seed}", flush=True) |
||||
print('WARNING: CUDA is not available, thus CUDA RNG cannot be used to track CUDA random number states', |
||||
flush=True) |
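# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): the order in which a
# ParallelContext is typically driven, shown here as a minimal single-process
# run on the gloo backend. The argparse-style Namespace stands in for whatever
# launcher normally supplies the distributed arguments.

if __name__ == '__main__':
    from argparse import Namespace

    ctx = ParallelContext()
    ctx.set_dist_args(Namespace(local_rank=0, world_size=1, backend='gloo'))
    ctx.load_config({'parallel': {'pipeline': 1, 'tensor': {'size': 1, 'mode': None}}})

    ctx.init_global_dist()       # sets up the GLOBAL process group
    ctx.init_parallel_groups()   # builds data / pipeline / tensor groups from the config
    ctx.check_sanity()           # world_size == dp * pp * tp
    ctx.set_seed()               # per-mode RNG seeds (CUDA states only if available)
    ctx.destroy()                # tear the process groups down again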
@ -0,0 +1,44 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
from enum import Enum |
||||
|
||||
|
||||
# parallel modes |
||||
class ParallelMode(Enum): |
||||
"""This is an enumeration class containing all possible parallel modes. |
||||
""" |
||||
|
||||
GLOBAL = 'global' |
||||
|
||||
# common parallel |
||||
DATA = 'data' |
||||
|
||||
# pipeline parallel |
||||
PIPELINE = 'pipe' |
||||
PIPELINE_PREV = 'pipe_prev' |
||||
PIPELINE_NEXT = 'pipe_next' |
||||
|
||||
# containing all ranks in tensor parallel |
||||
TENSOR = 'tensor' |
||||
|
||||
# sequence parallel |
||||
SEQUENCE = 'sequence' |
||||
|
||||
# 1D Parallel |
||||
PARALLEL_1D = '1d' |
||||
|
||||
# 2D parallel |
||||
PARALLEL_2D_ROW = '2d_row' |
||||
PARALLEL_2D_COL = '2d_col' |
||||
|
||||
# 3D parallel |
||||
PARALLEL_3D_INPUT = '3d_input' |
||||
PARALLEL_3D_WEIGHT = '3d_weight' |
||||
PARALLEL_3D_OUTPUT = '3d_output' |
||||
|
||||
# 2.5D parallel |
||||
PARALLEL_2P5D_ROW = '2p5d_row' |
||||
PARALLEL_2P5D_COL = '2p5d_col' |
||||
PARALLEL_2P5D_DEP = '2p5d_dep' |
||||
PARALLEL_2P5D_XZ = '2p5d_xz' |
@ -0,0 +1,15 @@
|
||||
from .initializer_1d import Initializer_1D |
||||
from .initializer_2d import Initializer_2D |
||||
from .initializer_2p5d import Initializer_2p5D |
||||
from .initializer_3d import Initializer_3D |
||||
from .initializer_data import Initializer_Data |
||||
from .initializer_pipeline import Initializer_Pipeline |
||||
from .initializer_sequence import Initializer_Sequence |
||||
from .initializer_tensor import Initializer_Tensor |
||||
from .process_group_initializer import ProcessGroupInitializer |
||||
|
||||
__all__ = [ |
||||
'Initializer_Tensor', 'Initializer_Sequence', 'Initializer_Pipeline', |
||||
'Initializer_Data', 'Initializer_2p5D', 'Initializer_2D', 'Initializer_3D', |
||||
'Initializer_1D', 'ProcessGroupInitializer' |
||||
] |
@ -0,0 +1,44 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
import torch.distributed as dist |
||||
|
||||
from colossalai.context import Config |
||||
from colossalai.core import global_context as gpc |
||||
from colossalai.registry import DIST_GROUP_INITIALIZER |
||||
from .process_group_initializer import ProcessGroupInitializer |
||||
from ..parallel_mode import ParallelMode |
||||
|
||||
|
||||
@DIST_GROUP_INITIALIZER.register_module |
||||
class Initializer_1D(ProcessGroupInitializer): |
||||
'''A ProcessGroupInitializer for 1d tensor parallelism. |
||||
''' |
||||
|
||||
def __init__(self, *args, **kwargs): |
||||
super().__init__(*args, **kwargs) |
||||
self.num_group = self.world_size // self.tensor_parallel_size |
||||
|
||||
def init_dist_group(self): |
||||
'''Initialize 1D tensor parallel groups, and assign local_ranks and groups to each gpu. |
||||
|
||||
:return: (local_rank, group_world_size, process_group, ranks_in_group, mode) |
||||
:rtype: tuple |
||||
''' |
||||
local_rank = None |
||||
ranks_in_group = None |
||||
process_group = None |
||||
group_world_size = None |
||||
mode = ParallelMode.PARALLEL_1D |
||||
|
||||
for i in range(self.num_group): |
||||
ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)] |
||||
group = dist.new_group(ranks) |
||||
|
||||
if self.rank in ranks: |
||||
local_rank = ranks.index(self.rank) |
||||
group_world_size = len(ranks) |
||||
process_group = group |
||||
ranks_in_group = ranks |
||||
|
||||
return local_rank, group_world_size, process_group, ranks_in_group, mode |
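# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): the rank layout produced
# by the loop above, computed without creating any process groups. With 8 GPUs
# and tensor_parallel_size=2, the 1D groups are [0, 1], [2, 3], [4, 5], [6, 7].

if __name__ == '__main__':
    world_size, tensor_parallel_size = 8, 2
    num_group = world_size // tensor_parallel_size
    groups = [[i * tensor_parallel_size + j for j in range(tensor_parallel_size)]
              for i in range(num_group)]
    assert groups == [[0, 1], [2, 3], [4, 5], [6, 7]]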
@ -0,0 +1,123 @@
|
||||
import math |
||||
import os |
||||
|
||||
import torch.distributed as dist |
||||
|
||||
from colossalai.constants import SUMMA_DIM |
||||
from colossalai.registry import DIST_GROUP_INITIALIZER |
||||
from .process_group_initializer import ProcessGroupInitializer |
||||
from ..parallel_mode import ParallelMode |
||||
|
||||
|
||||
def _check_summa_env_var(summa_dim): |
||||
# check environment variable for SUMMA |
||||
env_summa_dim = os.environ.get(SUMMA_DIM, None) |
||||
|
||||
if env_summa_dim: |
||||
assert int(env_summa_dim) == summa_dim, \ |
||||
'SUMMA_DIM has been set in the current environment and ' \ |
||||
'does not match the value passed to this initializer' |
||||
else: |
||||
os.environ[SUMMA_DIM] = str(summa_dim) |
||||
|
||||
|
||||
class Initializer_2D_Row(ProcessGroupInitializer): |
||||
'''2d tensor parallel initialization among rows. |
||||
''' |
||||
|
||||
def __init__(self, num_group, summa_dim, *args, **kwargs): |
||||
super(Initializer_2D_Row, self).__init__(*args, **kwargs) |
||||
self.num_group = num_group |
||||
self.summa_dim = summa_dim |
||||
|
||||
def init_dist_group(self): |
||||
'''Initialize 2D tensor row parallel groups, and assign local_ranks and groups to each gpu. |
||||
|
||||
:return: 2D tensor row parallelism's information |
||||
:rtype: tuple(local_rank, group_world_size, process_group, ranks_in_group, mode) |
||||
''' |
||||
local_rank = None |
||||
ranks_in_group = None |
||||
process_group = None |
||||
group_world_size = None |
||||
mode = ParallelMode.PARALLEL_2D_ROW |
||||
|
||||
for i in range(self.num_group): |
||||
for j in range(self.summa_dim): |
||||
ranks = [i * self.tensor_parallel_size + j * self.summa_dim + k |
||||
for k in range(self.summa_dim)] |
||||
group = dist.new_group(ranks) |
||||
|
||||
if self.rank in ranks: |
||||
local_rank = ranks.index(self.rank) |
||||
group_world_size = len(ranks) |
||||
process_group = group |
||||
ranks_in_group = ranks |
||||
|
||||
return local_rank, group_world_size, process_group, ranks_in_group, mode |
||||
|
||||
|
||||
class Initializer_2D_Col(ProcessGroupInitializer): |
||||
'''2d tensor parallel initialization among cols. |
||||
''' |
||||
|
||||
def __init__(self, num_group, summa_dim, *args, **kwargs): |
||||
super(Initializer_2D_Col, self).__init__(*args, **kwargs) |
||||
self.num_group = num_group |
||||
self.summa_dim = summa_dim |
||||
|
||||
def init_dist_group(self): |
||||
'''Initialize 2D tensor col parallel groups, and assign local_ranks and groups to each gpu. |
||||
|
||||
:return: 2D tensor col parallelism's information |
||||
:rtype: tuple(local_rank, group_world_size, process_group, ranks_in_group, mode) |
||||
''' |
||||
local_rank = None |
||||
ranks_in_group = None |
||||
process_group = None |
||||
group_world_size = None |
||||
mode = ParallelMode.PARALLEL_2D_COL |
||||
|
||||
for i in range(self.num_group): |
||||
for j in range(self.summa_dim): |
||||
ranks = [i * self.tensor_parallel_size + j + k * self.summa_dim |
||||
for k in range(self.summa_dim)] |
||||
group = dist.new_group(ranks) |
||||
|
||||
if self.rank in ranks: |
||||
local_rank = ranks.index(self.rank) |
||||
group_world_size = len(ranks) |
||||
process_group = group |
||||
ranks_in_group = ranks |
||||
|
||||
return local_rank, group_world_size, process_group, ranks_in_group, mode |
||||
|
||||
|
||||
@DIST_GROUP_INITIALIZER.register_module |
||||
class Initializer_2D(ProcessGroupInitializer): |
||||
""" |
||||
Serve as the single entry point to 2D parallel initialization. |
||||
""" |
||||
|
||||
def __init__(self, *args, **kwargs): |
||||
super().__init__(*args, **kwargs) |
||||
self.num_group = self.world_size // self.tensor_parallel_size |
||||
self.summa_dim = int(math.sqrt(self.tensor_parallel_size)) |
||||
|
||||
assert self.tensor_parallel_size == self.summa_dim ** 2, \ |
||||
"2D summa dim should equal to tensor parallel size ^ 0.5" |
||||
_check_summa_env_var(self.summa_dim) |
||||
|
||||
self.col_initializer = Initializer_2D_Col(self.num_group, self.summa_dim, *args, **kwargs) |
||||
self.row_initializer = Initializer_2D_Row(self.num_group, self.summa_dim, *args, **kwargs) |
||||
|
||||
def init_dist_group(self): |
||||
'''Initialize 2D tensor row and col parallel groups, and assign local_ranks and groups to each gpu. |
||||
|
||||
:return: 2D tensor parallelism's information |
||||
:rtype: list of tuples (local_rank, group_world_size, process_group, ranks_in_group, mode) |
||||
''' |
||||
parallel_setting = [] |
||||
parallel_setting.append(self.row_initializer.init_dist_group()) |
||||
parallel_setting.append(self.col_initializer.init_dist_group()) |
||||
return parallel_setting |
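# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): the row/col rank layout
# of one 2D group with tensor_parallel_size=4 (summa_dim=2). Rows take
# consecutive ranks while columns stride by summa_dim, mirroring the two
# initializers above.

if __name__ == '__main__':
    summa_dim = 2
    group_offset = 0  # equals i * tensor_parallel_size for group i
    rows = [[group_offset + j * summa_dim + k for k in range(summa_dim)] for j in range(summa_dim)]
    cols = [[group_offset + j + k * summa_dim for k in range(summa_dim)] for j in range(summa_dim)]
    assert rows == [[0, 1], [2, 3]]
    assert cols == [[0, 2], [1, 3]]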
@ -0,0 +1,255 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
import math |
||||
import os |
||||
|
||||
import torch.distributed as dist |
||||
|
||||
from colossalai.constants import TESSERACT_DIM, TESSERACT_DEP |
||||
from colossalai.context import Config |
||||
from colossalai.core import global_context as gpc |
||||
from colossalai.registry import DIST_GROUP_INITIALIZER |
||||
from .process_group_initializer import ProcessGroupInitializer |
||||
from ..parallel_mode import ParallelMode |
||||
|
||||
|
||||
def _check_tesseract_env_var(tesseract_dim: int, |
||||
tesseract_dep: int): |
||||
# check environment variable for TESSERACT |
||||
env_tesseract_dim = os.environ.get(TESSERACT_DIM, None) |
||||
env_tesseract_dep = os.environ.get(TESSERACT_DEP, None) |
||||
|
||||
if env_tesseract_dim and env_tesseract_dep: |
||||
assert int(env_tesseract_dim) == tesseract_dim, \ |
||||
'TESSERACT_DIM has been set in the current environment and ' \ |
||||
'does not match the value passed to this initializer' |
||||
assert int(env_tesseract_dep) == tesseract_dep, \ |
||||
'TESSERACT_DEP has been set in the current environment and ' \ |
||||
'does not match the value passed to this initializer' |
||||
else: |
||||
os.environ[TESSERACT_DIM] = str(tesseract_dim) |
||||
os.environ[TESSERACT_DEP] = str(tesseract_dep) |
||||
|
||||
|
||||
# i row j col k dep |
||||
class Initializer_2p5D_ROW(ProcessGroupInitializer): |
||||
'''2p5d tensor parallel initialization among rows. |
||||
''' |
||||
|
||||
def __init__(self, |
||||
tesseract_dim: int, |
||||
tesseract_dep: int, |
||||
*args): |
||||
super(Initializer_2p5D_ROW, self).__init__(*args) |
||||
|
||||
self.tensor_parallel_size = gpc.tensor_parallel_size |
||||
self.num_group = self.world_size // self.tensor_parallel_size |
||||
self.tesseract_dep = tesseract_dep |
||||
self.tesseract_dim = tesseract_dim |
||||
assert self.tensor_parallel_size == self.tesseract_dim ** 2 * self.tesseract_dep, \ |
||||
"Tensor parallel size should be depth * dim ** 2 in 2.5D parallel" |
||||
|
||||
def init_dist_group(self): |
||||
'''Initialize 2p5D tensor row parallel groups, and assign local_ranks and groups to each gpu. |
||||
|
||||
:return: 2p5D tensor row parallelism's information |
||||
:rtype: tuple(local_rank, group_world_size, process_group, ranks_in_group, mode) |
||||
''' |
||||
local_rank = None |
||||
ranks_in_group = None |
||||
process_group = None |
||||
group_world_size = None |
||||
mode = ParallelMode.PARALLEL_2P5D_ROW |
||||
|
||||
for h in range(self.num_group): |
||||
for j in range(self.tesseract_dim): |
||||
for k in range(self.tesseract_dep): |
||||
ranks = [h * self.tensor_parallel_size + i + self.tesseract_dim * ( |
||||
j + self.tesseract_dim * k) for i in range(self.tesseract_dim)] |
||||
group = dist.new_group(ranks) |
||||
|
||||
if self.rank in ranks: |
||||
local_rank = ranks.index(self.rank) |
||||
group_world_size = len(ranks) |
||||
process_group = group |
||||
ranks_in_group = ranks |
||||
|
||||
return local_rank, group_world_size, process_group, ranks_in_group, mode |
||||
|
||||
|
||||
class Initializer_2p5D_Col(ProcessGroupInitializer): |
||||
'''2p5d tensor parallel initialization among cols. |
||||
''' |
||||
def __init__(self, |
||||
tesseract_dim: int, |
||||
tesseract_dep: int, |
||||
*args): |
||||
super(Initializer_2p5D_Col, self).__init__(*args) |
||||
|
||||
self.tensor_parallel_size = gpc.tensor_parallel_size |
||||
self.num_group = self.world_size // self.tensor_parallel_size |
||||
self.tesseract_dep = tesseract_dep |
||||
self.tesseract_dim = tesseract_dim |
||||
assert self.tensor_parallel_size == self.tesseract_dim ** 2 * self.tesseract_dep, \ |
||||
"Tensor parallel size should be depth * dim ** 2 in 2.5D parallel" |
||||
|
||||
def init_dist_group(self): |
||||
'''Initialize 2p5D tensor col parallel groups, and assign local_ranks and groups to each gpu. |
||||
|
||||
:return: 2p5D tensor col parallelism's information |
||||
:rtype: tuple(local_rank, group_world_size, process_group, ranks_in_group, mode) |
||||
''' |
||||
local_rank = None |
||||
ranks_in_group = None |
||||
process_group = None |
||||
group_world_size = None |
||||
mode = ParallelMode.PARALLEL_2P5D_COL |
||||
|
||||
for h in range(self.num_group): |
||||
for i in range(self.tesseract_dim): |
||||
for k in range(self.tesseract_dep): |
||||
ranks = [h * self.tensor_parallel_size + i + self.tesseract_dim * ( |
||||
j + self.tesseract_dim * k) for j in range(self.tesseract_dim)] |
||||
group = dist.new_group(ranks) |
||||
|
||||
if self.rank in ranks: |
||||
local_rank = ranks.index(self.rank) |
||||
group_world_size = len(ranks) |
||||
process_group = group |
||||
ranks_in_group = ranks |
||||
|
||||
return local_rank, group_world_size, process_group, ranks_in_group, mode |
||||
|
||||
|
||||
class Initializer_2p5D_Dep(ProcessGroupInitializer): |
||||
'''2p5D tensor parallel initialization among depths. |
||||
''' |
||||
def __init__(self, |
||||
tesseract_dim: int, |
||||
tesseract_dep: int, |
||||
*args): |
||||
super(Initializer_2p5D_Dep, self).__init__(*args) |
||||
|
||||
self.tensor_parallel_size = gpc.tensor_parallel_size |
||||
self.num_group = self.world_size // self.tensor_parallel_size |
||||
self.tesseract_dep = tesseract_dep |
||||
self.tesseract_dim = tesseract_dim |
||||
assert self.tensor_parallel_size == self.tesseract_dim ** 2 * self.tesseract_dep, \ |
||||
"Tensor parallel size should be depth * dim ** 2 in 2.5D parallel" |
||||
|
||||
def init_dist_group(self): |
||||
'''Initialize 2p5D tensor depth parallel groups, and assign local_ranks and groups to each gpu. |
||||
|
||||
:return: 2p5D tensor depth parallelism's information |
||||
:rtype: tuple(local_rank, group_world_size, process_group, ranks_in_group, mode) |
||||
''' |
||||
local_rank = None |
||||
ranks_in_group = None |
||||
process_group = None |
||||
group_world_size = None |
||||
mode = ParallelMode.PARALLEL_2P5D_DEP |
||||
|
||||
for h in range(self.num_group): |
||||
for i in range(self.tesseract_dim): |
||||
for j in range(self.tesseract_dim): |
||||
ranks = [h * self.tensor_parallel_size + i + self.tesseract_dim * ( |
||||
j + self.tesseract_dim * k) for k in range(self.tesseract_dep)] |
||||
group = dist.new_group(ranks) |
||||
|
||||
if self.rank in ranks: |
||||
local_rank = ranks.index(self.rank) |
||||
group_world_size = len(ranks) |
||||
process_group = group |
||||
ranks_in_group = ranks |
||||
|
||||
return local_rank, group_world_size, process_group, ranks_in_group, mode |
||||
|
||||
|
||||
# i row j col k dep |
||||
class Initializer_2p5D_XZ(ProcessGroupInitializer): |
||||
'''2p5d tensor parallel initialization along the col-times-dep (XZ) dimension. |
||||
''' |
||||
def __init__(self, |
||||
tesseract_dim: int, |
||||
tesseract_dep: int, |
||||
*args): |
||||
super(Initializer_2p5D_XZ, self).__init__(*args) |
||||
|
||||
self.tensor_parallel_size = gpc.tensor_parallel_size |
||||
self.num_group = self.world_size // self.tensor_parallel_size |
||||
self.tesseract_dep = tesseract_dep |
||||
self.tesseract_dim = tesseract_dim |
||||
assert self.tensor_parallel_size == self.tesseract_dim ** 2 * self.tesseract_dep, \ |
||||
"Tensor parallel size should be depth * dim ** 2 in 2.5D parallel" |
||||
|
||||
def init_dist_group(self): |
||||
'''Initialize 2p5D tensor colXdepth parallel groups, and assign local_ranks and groups to each gpu. |
||||
|
||||
:return: 2p5D tensor colXdepth parallelism's information |
||||
:rtype: tuple(local_rank, group_world_size, process_group, ranks_in_group, mode) |
||||
''' |
||||
local_rank = None |
||||
ranks_in_group = None |
||||
process_group = None |
||||
group_world_size = None |
||||
mode = ParallelMode.PARALLEL_2P5D_XZ |
||||
|
||||
for h in range(self.num_group): |
||||
for i in range(self.tesseract_dim): |
||||
ranks = [h * self.tensor_parallel_size + i + self.tesseract_dim * ( |
||||
j + self.tesseract_dim * k) for k in range(self.tesseract_dep) for j in |
||||
range(self.tesseract_dim)] |
||||
group = dist.new_group(ranks) |
||||
|
||||
if self.rank in ranks: |
||||
local_rank = ranks.index(self.rank) |
||||
group_world_size = len(ranks) |
||||
process_group = group |
||||
ranks_in_group = ranks |
||||
|
||||
return local_rank, group_world_size, process_group, ranks_in_group, mode |
||||
|
||||
|
||||
@DIST_GROUP_INITIALIZER.register_module |
||||
class Initializer_2p5D(ProcessGroupInitializer): |
||||
""" |
||||
Serve as the single entry point to Tesseract parallel initialization. |
||||
""" |
||||
|
||||
def __init__(self, |
||||
rank: int, |
||||
world_size: int, |
||||
config: Config, |
||||
data_parallel_size: int, |
||||
pipeline_parallel_size: int, |
||||
tensor_parallel_size: int, |
||||
depth: int |
||||
): |
||||
args = (rank, world_size, config, data_parallel_size, pipeline_parallel_size, tensor_parallel_size) |
||||
super().__init__(*args) |
||||
self.num_group = self.world_size // self.tensor_parallel_size |
||||
self.tesseract_dim = int(math.sqrt(self.tensor_parallel_size / depth)) |
||||
self.tesseract_dep = depth |
||||
|
||||
assert self.tensor_parallel_size == self.tesseract_dim ** 2 * self.tesseract_dep, \ |
||||
"2.5D tesseract dim should equal to (tensor parallel size / tesseract dep) ^ 0.5" |
||||
_check_tesseract_env_var(self.tesseract_dim, self.tesseract_dep) |
||||
|
||||
self.col_initializer = Initializer_2p5D_Col(self.tesseract_dim, self.tesseract_dep, *args) |
||||
self.row_initializer = Initializer_2p5D_ROW(self.tesseract_dim, self.tesseract_dep, *args) |
||||
self.dep_initializer = Initializer_2p5D_Dep(self.tesseract_dim, self.tesseract_dep, *args) |
||||
self.xz_initializer = Initializer_2p5D_XZ(self.tesseract_dim, self.tesseract_dep, *args) |
||||
|
||||
def init_dist_group(self): |
||||
'''Initialize 2p5D tensor row, col, depth, and colXdepth parallel groups, and assign local_ranks and groups to each gpu. |
||||
|
||||
:return: Whole 2p5D tensor parallelism's information |
||||
:rtype: list of tuples (local_rank, group_world_size, process_group, ranks_in_group, mode) |
||||
''' |
||||
parallel_setting = [] |
||||
parallel_setting.append(self.col_initializer.init_dist_group()) |
||||
parallel_setting.append(self.row_initializer.init_dist_group()) |
||||
parallel_setting.append(self.dep_initializer.init_dist_group()) |
||||
parallel_setting.append(self.xz_initializer.init_dist_group()) |
||||
return parallel_setting |
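# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): how tensor_parallel_size,
# depth and tesseract_dim relate in 2.5D parallelism. With 8 GPUs per tensor
# group and depth=2, each of the 2 depth layers holds a 2x2 SUMMA grid.

if __name__ == '__main__':
    tensor_parallel_size, depth = 8, 2
    tesseract_dim = int(math.sqrt(tensor_parallel_size / depth))
    assert tensor_parallel_size == tesseract_dim ** 2 * depth   # 8 == 2 * 2 * 2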
@ -0,0 +1,172 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
import math |
||||
import os |
||||
|
||||
import torch.distributed as dist |
||||
from colossalai.constants import DEPTH_3D |
||||
from colossalai.registry import DIST_GROUP_INITIALIZER |
||||
|
||||
from ..parallel_mode import ParallelMode |
||||
from .process_group_initializer import ProcessGroupInitializer |
||||
|
||||
|
||||
def _check_depth_env_var(depth): |
||||
# check environment variable for 3D depth |
||||
env_depth = os.environ.get(DEPTH_3D, None) |
||||
|
||||
if env_depth: |
||||
assert int(env_depth) == depth, \ |
||||
'DEPTH_3D has been set in the current environment and ' |
||||
'does not match the value passed to this initializer' |
||||
else: |
||||
os.environ[DEPTH_3D] = str(depth) |
||||
|
||||
|
||||
class Initializer_3D_Input(ProcessGroupInitializer): |
||||
'''3D tensor parallel initialization among input. |
||||
''' |
||||
def __init__(self, num_group: int, depth: int, *args): |
||||
super().__init__(*args) |
||||
self.num_group = num_group |
||||
self.depth = depth |
||||
|
||||
def init_dist_group(self): |
||||
'''Initialize 3D tensor parallel groups among input, and assign local_ranks and groups to each gpu. |
||||
|
||||
:return: 3D tensor parallelism's information among input |
||||
:rtype: tuple(local_rank, group_world_size, process_group, ranks_in_group, mode) |
||||
''' |
||||
local_rank = None |
||||
ranks_in_group = None |
||||
process_group = None |
||||
group_world_size = None |
||||
mode = ParallelMode.PARALLEL_3D_INPUT |
||||
|
||||
for h in range(self.num_group): |
||||
for i in range(self.depth): |
||||
for k in range(self.depth): |
||||
ranks = [ |
||||
h * self.depth**3 + i + self.depth * |
||||
(j + self.depth * k) for j in range(self.depth) |
||||
] |
||||
group = dist.new_group(ranks) |
||||
|
||||
if self.rank in ranks: |
||||
local_rank = ranks.index(self.rank) |
||||
group_world_size = len(ranks) |
||||
process_group = group |
||||
ranks_in_group = ranks |
||||
|
||||
return local_rank, group_world_size, process_group, ranks_in_group, mode |
||||
|
||||
|
||||
class Initializer_3D_Weight(ProcessGroupInitializer): |
||||
'''3D tensor parallel initialization among weight. |
||||
''' |
||||
|
||||
def __init__(self, num_group: int, depth: int, *args): |
||||
super().__init__(*args) |
||||
self.num_group = num_group |
||||
self.depth = depth |
||||
|
||||
def init_dist_group(self): |
||||
'''Initialize 3D tensor parallel groups among weight, and assign local_ranks and groups to each gpu. |
||||
|
||||
:return: 3D tensor parallelism's information among weight |
||||
:rtype: tuple(local_rank, group_world_size, process_group, ranks_in_group, mode) |
||||
''' |
||||
local_rank = None |
||||
ranks_in_group = None |
||||
process_group = None |
||||
group_world_size = None |
||||
mode = ParallelMode.PARALLEL_3D_WEIGHT |
||||
|
||||
for h in range(self.num_group): |
||||
for k in range(self.depth): |
||||
for j in range(self.depth): |
||||
ranks = [ |
||||
h * self.depth**3 + i + self.depth * |
||||
(j + self.depth * k) for i in range(self.depth) |
||||
] |
||||
group = dist.new_group(ranks) |
||||
|
||||
if self.rank in ranks: |
||||
local_rank = ranks.index(self.rank) |
||||
group_world_size = len(ranks) |
||||
process_group = group |
||||
ranks_in_group = ranks |
||||
|
||||
return local_rank, group_world_size, process_group, ranks_in_group, mode |
||||
|
||||
|
||||
class Initializer_3D_Output(ProcessGroupInitializer): |
||||
'''3D tensor parallel initialization among output. |
||||
''' |
||||
|
||||
def __init__(self, num_group: int, depth: int, *args): |
||||
super().__init__(*args) |
||||
self.num_group = num_group |
||||
self.depth = depth |
||||
|
||||
def init_dist_group(self): |
||||
'''Initialize 3D tensor parallel groups among output, and assign local_ranks and groups to each gpu. |
||||
|
||||
:return: 3D tensor parallelism's information among output |
||||
:rtype: tuple(local_rank, group_world_size, process_group, ranks_in_group, mode) |
||||
''' |
||||
local_rank = None |
||||
ranks_in_group = None |
||||
process_group = None |
||||
group_world_size = None |
||||
mode = ParallelMode.PARALLEL_3D_OUTPUT |
||||
|
||||
for h in range(self.num_group): |
||||
for i in range(self.depth): |
||||
for j in range(self.depth): |
||||
ranks = [ |
||||
h * self.depth**3 + i + self.depth * |
||||
(j + self.depth * k) for k in range(self.depth) |
||||
] |
||||
group = dist.new_group(ranks) |
||||
|
||||
if self.rank in ranks: |
||||
local_rank = ranks.index(self.rank) |
||||
group_world_size = len(ranks) |
||||
process_group = group |
||||
ranks_in_group = ranks |
||||
|
||||
return local_rank, group_world_size, process_group, ranks_in_group, mode |
||||
|
||||
|
||||
@DIST_GROUP_INITIALIZER.register_module |
||||
class Initializer_3D(ProcessGroupInitializer): |
||||
'''Serve as the single entry point to 3D parallel initialization. |
||||
''' |
||||
def __init__(self, *args): |
||||
super().__init__(*args) |
||||
self.num_group = self.world_size // self.tensor_parallel_size |
||||
self.depth = round(math.pow(self.tensor_parallel_size, 1 / 3)) |
||||
assert self.tensor_parallel_size == self.depth ** 3, \ |
||||
f'3D depth ({self.depth}) is not the cube root of tensor parallel size ({self.tensor_parallel_size})' |
||||
_check_depth_env_var(self.depth) |
||||
|
||||
self.input_initializer = Initializer_3D_Input(self.num_group, |
||||
self.depth, *args) |
||||
self.weight_initializer = Initializer_3D_Weight( |
||||
self.num_group, self.depth, *args) |
||||
self.output_initializer = Initializer_3D_Output( |
||||
self.num_group, self.depth, *args) |
||||
|
||||
def init_dist_group(self): |
||||
'''Initialize 3D tensor parallel groups, and assign local_ranks and groups to each gpu. |
||||
|
||||
:return: 3D tensor parallelism's information |
||||
:rtype: list of tuples (local_rank, group_world_size, process_group, ranks_in_group, mode) |
||||
''' |
||||
parallel_setting = [] |
||||
parallel_setting.append(self.input_initializer.init_dist_group()) |
||||
parallel_setting.append(self.weight_initializer.init_dist_group()) |
||||
parallel_setting.append(self.output_initializer.init_dist_group()) |
||||
return parallel_setting |
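# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): 3D parallelism requires
# the tensor parallel size to be a perfect cube; the depth is its cube root.
# With 8 GPUs per tensor group, depth=2 and each input/weight/output group
# holds 2 ranks.

if __name__ == '__main__':
    tensor_parallel_size = 8
    depth = round(math.pow(tensor_parallel_size, 1 / 3))
    assert tensor_parallel_size == depth ** 3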
@ -0,0 +1,41 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
from torch import distributed as dist |
||||
|
||||
from colossalai.registry import DIST_GROUP_INITIALIZER |
||||
from .process_group_initializer import ProcessGroupInitializer |
||||
from ..parallel_mode import ParallelMode |
||||
|
||||
|
||||
@DIST_GROUP_INITIALIZER.register_module |
||||
class Initializer_Data(ProcessGroupInitializer): |
||||
'''A ProcessGroupInitializer for data parallelism. |
||||
''' |
||||
def __init__(self, *args, **kwargs): |
||||
super().__init__(*args, **kwargs) |
||||
self.num_data_parallel_group = self.world_size // self.data_parallel_size |
||||
|
||||
def init_dist_group(self): |
||||
'''Initialize data parallel groups, and assign local_ranks and groups to each gpu. |
||||
|
||||
:return: data parallelism's information |
||||
:rtype: tuple (local_rank, group_world_size, process_group, ranks_in_group, mode) |
||||
''' |
||||
local_rank = None |
||||
ranks_in_group = None |
||||
process_group = None |
||||
group_world_size = None |
||||
mode = ParallelMode.DATA |
||||
|
||||
for i in range(self.num_data_parallel_group): |
||||
ranks = [i + j * self.num_data_parallel_group for j in range(self.data_parallel_size)] |
||||
group = dist.new_group(ranks) |
||||
|
||||
if self.rank in ranks: |
||||
local_rank = ranks.index(self.rank) |
||||
group_world_size = len(ranks) |
||||
process_group = group |
||||
ranks_in_group = ranks |
||||
|
||||
return local_rank, group_world_size, process_group, ranks_in_group, mode |
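# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): data parallel groups
# stride across the model parallel groups. With world_size=8 and
# data_parallel_size=2, ranks {0, 4}, {1, 5}, {2, 6}, {3, 7} each form one
# data parallel group.

if __name__ == '__main__':
    world_size, data_parallel_size = 8, 2
    num_data_parallel_group = world_size // data_parallel_size
    groups = [[i + j * num_data_parallel_group for j in range(data_parallel_size)]
              for i in range(num_data_parallel_group)]
    assert groups == [[0, 4], [1, 5], [2, 6], [3, 7]]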
@ -0,0 +1,63 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
from torch import distributed as dist |
||||
|
||||
from colossalai.registry import DIST_GROUP_INITIALIZER |
||||
from .process_group_initializer import ProcessGroupInitializer |
||||
from ..parallel_mode import ParallelMode |
||||
|
||||
|
||||
@DIST_GROUP_INITIALIZER.register_module |
||||
class Initializer_Pipeline(ProcessGroupInitializer): |
||||
def __init__(self, *args, **kwargs): |
||||
super().__init__(*args, **kwargs) |
||||
self.data_group_size = self.world_size // self.data_parallel_size |
||||
self.pipeline_stage_size = self.data_group_size // self.pipeline_parallel_size |
||||
|
||||
def init_dist_group(self): |
||||
dist_settings = list() |
||||
for i in range(self.data_parallel_size): |
||||
for j in range(self.pipeline_stage_size): |
||||
pipe_ranks = list( |
||||
range(i * self.data_group_size + j, |
||||
(i + 1) * self.data_group_size, |
||||
self.pipeline_stage_size)) |
||||
pipe_group_size = len(pipe_ranks) |
||||
pipe_group = dist.new_group(pipe_ranks) |
||||
|
||||
if self.rank in pipe_ranks: |
||||
local_rank = pipe_ranks.index(self.rank) |
||||
group_world_size = pipe_group_size |
||||
process_group = pipe_group |
||||
ranks_in_group = pipe_ranks |
||||
dist_settings.append( |
||||
tuple((local_rank, group_world_size, |
||||
process_group, ranks_in_group, |
||||
ParallelMode.PIPELINE))) |
||||
|
||||
for k in range(pipe_group_size): |
||||
first = pipe_ranks[k] |
||||
second = pipe_ranks[(k + 1) % pipe_group_size] |
||||
ranks = [first, second] |
||||
group = dist.new_group(ranks) |
||||
if self.rank == first: |
||||
local_rank = 0 |
||||
group_world_size = 2 |
||||
process_group = group |
||||
ranks_in_group = ranks |
||||
dist_settings.append( |
||||
tuple((local_rank, group_world_size, |
||||
process_group, ranks_in_group, |
||||
ParallelMode.PIPELINE_NEXT))) |
||||
elif self.rank == second: |
||||
local_rank = 1 |
||||
group_world_size = 2 |
||||
process_group = group |
||||
ranks_in_group = ranks |
||||
dist_settings.append( |
||||
tuple((local_rank, group_world_size, |
||||
process_group, ranks_in_group, |
||||
ParallelMode.PIPELINE_PREV))) |
||||
|
||||
return dist_settings |
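# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): the PIPELINE groups built
# by the loops above, computed without any process groups. With world_size=8,
# data_parallel_size=2 and pipeline_parallel_size=2, each replica of 4 ranks is
# split into the pipelines [0, 2], [1, 3] and [4, 6], [5, 7].

if __name__ == '__main__':
    world_size, data_parallel_size, pipeline_parallel_size = 8, 2, 2
    data_group_size = world_size // data_parallel_size                # ranks per replica
    pipeline_stage_size = data_group_size // pipeline_parallel_size   # pipelines per replica
    pipes = [list(range(i * data_group_size + j, (i + 1) * data_group_size, pipeline_stage_size))
             for i in range(data_parallel_size)
             for j in range(pipeline_stage_size)]
    assert pipes == [[0, 2], [1, 3], [4, 6], [5, 7]]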
@ -0,0 +1,27 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
from colossalai.registry import DIST_GROUP_INITIALIZER |
||||
from .initializer_tensor import Initializer_Tensor |
||||
from .process_group_initializer import ProcessGroupInitializer |
||||
from ..parallel_mode import ParallelMode |
||||
|
||||
|
||||
@DIST_GROUP_INITIALIZER.register_module |
||||
class Initializer_Sequence(ProcessGroupInitializer): |
||||
'''A ProcessGroupInitializer for sequence parallelism. |
||||
''' |
||||
|
||||
def __init__(self, |
||||
*args, **kwargs): |
||||
super().__init__(*args, **kwargs) |
||||
# reuse tensor parallel code |
||||
self._initializer = Initializer_Tensor(*args, **kwargs) |
||||
|
||||
def init_dist_group(self): |
||||
local_rank, group_world_size, process_group, ranks_in_group, mode = self._initializer.init_dist_group() |
||||
|
||||
# change mode to sequence |
||||
mode = ParallelMode.SEQUENCE |
||||
|
||||
return local_rank, group_world_size, process_group, ranks_in_group, mode |
@ -0,0 +1,41 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
import torch.distributed as dist |
||||
|
||||
from colossalai.registry import DIST_GROUP_INITIALIZER |
||||
from .process_group_initializer import ProcessGroupInitializer |
||||
from ..parallel_mode import ParallelMode |
||||
|
||||
|
||||
@DIST_GROUP_INITIALIZER.register_module |
||||
class Initializer_Tensor(ProcessGroupInitializer): |
||||
'''A ProcessGroupInitializer for tensor parallelism. |
||||
''' |
||||
def __init__(self, *args, **kwargs): |
||||
super().__init__(*args, **kwargs) |
||||
self.num_tensor_parallel_group = self.world_size // self.tensor_parallel_size |
||||
|
||||
def init_dist_group(self): |
||||
'''Initialize tensor parallel groups, and assign local_ranks and groups to each gpu. |
||||
|
||||
:return: tensor parallelism's information |
||||
:rtype: tuple(local_rank, group_world_size, process_group, ranks_in_group, mode) |
||||
''' |
||||
local_rank = None |
||||
ranks_in_group = None |
||||
process_group = None |
||||
group_world_size = None |
||||
mode = ParallelMode.TENSOR |
||||
|
||||
for i in range(self.num_tensor_parallel_group): |
||||
ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)] |
||||
group = dist.new_group(ranks) |
||||
|
||||
if self.rank in ranks: |
||||
local_rank = ranks.index(self.rank) |
||||
group_world_size = len(ranks) |
||||
process_group = group |
||||
ranks_in_group = ranks |
||||
|
||||
return local_rank, group_world_size, process_group, ranks_in_group, mode |
@ -0,0 +1,30 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
from abc import ABC, abstractmethod |
||||
|
||||
from colossalai.context import Config |
||||
|
||||
|
||||
class ProcessGroupInitializer(ABC): |
||||
'''An object that knows the parallelism configuration and initializes the corresponding parallel groups. |
||||
''' |
||||
def __init__(self, |
||||
rank: int, |
||||
world_size: int, |
||||
config: Config, |
||||
data_parallel_size: int, |
||||
pipeline_parallel_size: int, |
||||
tensor_parallel_size: int |
||||
): |
||||
self.rank = rank |
||||
self.world_size = world_size |
||||
self.data_parallel_size = data_parallel_size |
||||
self.config = config |
||||
self.pipeline_parallel_size = pipeline_parallel_size |
||||
self.tensor_parallel_size = tensor_parallel_size |
||||
super().__init__() |
||||
|
||||
@abstractmethod |
||||
def init_dist_group(self): |
||||
pass |
@ -0,0 +1,8 @@
|
||||
from ._helper import (seed, set_mode, with_seed, add_seed, |
||||
get_seeds, get_states, get_current_mode, |
||||
set_seed_states, sync_states) |
||||
|
||||
__all__ = [ |
||||
'seed', 'set_mode', 'with_seed', 'add_seed', 'get_seeds', |
||||
'get_states', 'get_current_mode', 'set_seed_states', 'sync_states' |
||||
] |
@ -0,0 +1,144 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
import functools |
||||
from contextlib import contextmanager |
||||
|
||||
import torch.cuda |
||||
from torch import Tensor |
||||
|
||||
from .seed_manager import SeedManager |
||||
from ..parallel_mode import ParallelMode |
||||
|
||||
_SEED_MANAGER = SeedManager() |
||||
|
||||
|
||||
def get_seeds(): |
||||
"""Returns the seeds of the seed manager. |
||||
|
||||
:return: The seeds of the seed manager |
||||
:rtype: dict |
||||
""" |
||||
return _SEED_MANAGER.seeds |
||||
|
||||
|
||||
def get_states(copy=False): |
||||
"""Returns the seed states of the seed manager. |
||||
|
||||
:return: The seed states of the seed manager |
||||
:rtype: dict |
||||
""" |
||||
states = _SEED_MANAGER.seed_states |
||||
|
||||
if copy: |
||||
new_states = dict() |
||||
|
||||
for parallel_mode, state in states.items(): |
||||
new_states[parallel_mode] = state.clone() |
||||
return new_states |
||||
else: |
||||
return _SEED_MANAGER.seed_states |
||||
|
||||
|
||||
def get_current_mode(): |
||||
"""Returns the current mode of the seed manager. |
||||
|
||||
:return: The current mode of the seed manager. |
||||
:rtype: :class:`colossalai.context.ParallelMode` |
||||
""" |
||||
return _SEED_MANAGER.current_mode |
||||
|
||||
|
||||
def add_seed(parallel_mode: ParallelMode, seed: int): |
||||
"""Adds a seed to the seed manager for `parallel_mode`. |
||||
|
||||
:param parallel_mode: The chosen parallel mode |
||||
:type parallel_mode: :class:`colossalai.context.ParallelMode` |
||||
:param seed: The seed to be added |
||||
:type seed: int |
||||
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of |
||||
:class:`colossalai.context.ParallelMode` or the seed for `parallel_mode` has been added |
||||
""" |
||||
_SEED_MANAGER.add_seed(parallel_mode, seed) |
||||
|
||||
|
||||
def set_mode(parallel_mode: ParallelMode): |
||||
"""Sets the current mode of the seed manager. |
||||
|
||||
:param parallel_mode: The chosen parallel mode |
||||
:type parallel_mode: :class:`colossalai.context.ParallelMode` |
||||
""" |
||||
_SEED_MANAGER.set_mode(parallel_mode) |
||||
|
||||
|
||||
def set_seed_states(parallel_mode: ParallelMode, state: Tensor): |
||||
"""Sets the state of the seed manager for `parallel_mode`. |
||||
|
||||
:param parallel_mode: The chosen parallel mode |
||||
:type parallel_mode: :class:`colossalai.context.ParallelMode` |
||||
:param state: the state to be set |
||||
:type state: :class:`torch.Tensor` |
||||
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not found in the seed manager |
||||
""" |
||||
_SEED_MANAGER.set_state(parallel_mode, state) |
||||
|
||||
|
||||
def sync_states(): |
||||
current_mode = get_current_mode() |
||||
current_states = torch.cuda.get_rng_state() |
||||
set_seed_states(current_mode, current_states) |
||||
|
||||
|
||||
@contextmanager |
||||
def seed(parallel_mode: ParallelMode): |
||||
""" A context for seed switch |
||||
|
||||
Examples:: |
||||
|
||||
with seed(ParallelMode.DATA): |
||||
output = F.dropout(input) |
||||
|
||||
""" |
||||
try: |
||||
# set to new mode |
||||
current_mode = _SEED_MANAGER.current_mode |
||||
yield _SEED_MANAGER.set_mode(parallel_mode) |
||||
finally: |
||||
# recover |
||||
_SEED_MANAGER.set_mode(current_mode) |
||||
|
||||
|
||||
def with_seed(func, parallel_mode: ParallelMode): |
||||
""" |
||||
A function wrapper which executes the function with a specified seed. |
||||
|
||||
Examples:: |
||||
|
||||
# use with decorator |
||||
@with_seed(ParallelMode.DATA) |
||||
def forward(input): |
||||
return F.dropout(input) |
||||
out = forward(input) |
||||
# OR use it inline |
||||
def forward(input): |
||||
return F.dropout(input) |
||||
wrapper_forward = with_seed(forward, ParallelMode.DATA) |
||||
out = wrapper_forward(input) |
||||
|
||||
""" |
||||
|
||||
@functools.wraps(func) |
||||
def wrapper(*args, **kwargs): |
||||
# switch mode |
||||
current_mode = _SEED_MANAGER.current_mode |
||||
_SEED_MANAGER.set_mode(parallel_mode) |
||||
|
||||
# exec func |
||||
out = func(*args, **kwargs) |
||||
|
||||
# recover state |
||||
_SEED_MANAGER.set_mode(current_mode) |
||||
|
||||
return out |
||||
|
||||
return wrapper |
@ -0,0 +1,74 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
import torch |
||||
from torch import Tensor |
||||
|
||||
from colossalai.context.parallel_mode import ParallelMode |
||||
|
||||
|
||||
class SeedManager: |
||||
"""This class is a manager of all random seeds involved in the system. |
||||
""" |
||||
|
||||
def __init__(self): |
||||
self._current_mode = None |
||||
self._seeds = dict() |
||||
self._seed_states = dict() |
||||
|
||||
@property |
||||
def current_mode(self): |
||||
return self._current_mode |
||||
|
||||
@property |
||||
def seeds(self): |
||||
return self._seeds |
||||
|
||||
@property |
||||
def seed_states(self): |
||||
return self._seed_states |
||||
|
||||
def set_state(self, parallel_mode: ParallelMode, state: Tensor): |
||||
"""Sets the state of the seed manager for `parallel_mode`. |
||||
|
||||
:param parallel_mode: The chosen parallel mode |
||||
:type parallel_mode: :class:`colossalai.context.ParallelMode` |
||||
:param state: the state to be set |
||||
:type state: :class:`torch.Tensor` |
||||
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not found in the seed manager |
||||
""" |
||||
assert parallel_mode in self._seed_states, f'Parallel mode {parallel_mode} is not found in the seed manager' |
||||
self._seed_states[parallel_mode] = state |
||||
|
||||
def set_mode(self, parallel_mode: ParallelMode): |
||||
"""Sets the current mode of the seed manager. |
||||
|
||||
:param parallel_mode: The chosen parallel mode |
||||
:type parallel_mode: :class:`colossalai.context.ParallelMode` |
||||
""" |
||||
if self.current_mode: |
||||
# save the current state for current mode |
||||
self._seed_states[self._current_mode] = torch.cuda.get_rng_state() |
||||
|
||||
# set the new state for new mode |
||||
self._current_mode = parallel_mode |
||||
torch.cuda.set_rng_state(self._seed_states[parallel_mode]) |
||||
|
||||
def add_seed(self, parallel_mode: ParallelMode, seed: int): |
||||
"""Adds a seed to the seed manager for `parallel_mode`. |
||||
|
||||
:param parallel_mode: The chosen parallel mode |
||||
:type parallel_mode: :class:`colossalai.context.ParallelMode` |
||||
:param seed: The seed to be added |
||||
:type seed: int |
||||
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of |
||||
:class:`colossalai.context.ParallelMode` or the seed for `parallel_mode` has been added |
||||
""" |
||||
assert isinstance( |
||||
parallel_mode, ParallelMode), 'A valid ParallelMode must be provided' |
||||
assert parallel_mode not in self._seed_states, f'The seed for {parallel_mode} has been added' |
||||
current_state = torch.cuda.get_rng_state() |
||||
torch.cuda.manual_seed(seed) |
||||
self._seed_states[parallel_mode] = torch.cuda.get_rng_state() |
||||
self._seeds[parallel_mode] = seed |
||||
torch.cuda.set_rng_state(current_state) |
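The CUDA RNG bookkeeping above can be illustrated with a single-process CPU analogue. The sketch below is illustrative only: it saves and restores torch's CPU RNG state the same way `SeedManager` handles the `torch.cuda` states, so two hypothetical 'modes' keep independent, reproducible random streams.

```python
# CPU analogue of SeedManager's add_seed / set_mode bookkeeping (sketch only).
import torch

_states = {}

def add_seed(mode, seed):
    # Record the RNG state that corresponds to this seed without disturbing
    # the currently active state.
    saved = torch.get_rng_state()
    torch.manual_seed(seed)
    _states[mode] = torch.get_rng_state()
    torch.set_rng_state(saved)

def set_mode(current, new):
    # Save the outgoing mode's state, then load the incoming one.
    if current is not None:
        _states[current] = torch.get_rng_state()
    torch.set_rng_state(_states[new])
    return new

add_seed('data', 1024)
add_seed('tensor', 42)

mode = set_mode(None, 'data')
a = torch.rand(2)            # drawn from the 'data' stream
mode = set_mode(mode, 'tensor')
b = torch.rand(2)            # drawn from the independent 'tensor' stream
mode = set_mode(mode, 'data')
c = torch.rand(2)            # continues the 'data' stream right after `a`
```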
@ -0,0 +1,16 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
from colossalai.context import ParallelContext |
||||
|
||||
global_context = ParallelContext() |
||||
|
||||
|
||||
def set_global_context(context: ParallelContext): |
||||
'''Resets the global context to be identical to a given :class:`ParallelContext`. |
||||
|
||||
:param context: Parallel context to generate our global parallel context. |
||||
:type context: ParallelContext |
||||
''' |
||||
global global_context |
||||
global_context = context |
@ -0,0 +1,7 @@
|
||||
from .amp_type import AMP_TYPE |
||||
from ._base_engine import Engine |
||||
from .gradient_handler import * |
||||
from .schedule import * |
||||
|
||||
|
||||
__all__ = ['Engine'] |
@ -0,0 +1,170 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
from typing import Optional |
||||
|
||||
from colossalai.builder import build_gradient_handler |
||||
from colossalai.context import ParallelMode |
||||
from colossalai.core import global_context as gpc |
||||
from colossalai.logging import get_global_dist_logger |
||||
from colossalai.nn import (ZeroRedundancyOptimizer_Level_2, |
||||
ZeroRedundancyOptimizer_Level_3) |
||||
from torch.nn import Module |
||||
from torch.nn.modules.loss import _Loss |
||||
from torch.optim import Optimizer |
||||
from torch.optim.lr_scheduler import _LRScheduler |
||||
from torch.utils.data import DataLoader |
||||
|
||||
from .schedule import BaseSchedule, NoPipelineSchedule |
||||
|
||||
|
||||
class Engine: |
||||
"""Basic engine class for training and evaluation. It runs a specific process method |
||||
:meth:`step` which is based on the given :attr:`schedule` over each batch of a dataset. |
||||
|
||||
:param train_dataloader: Dataloader in training |
||||
:param test_dataloader: Dataloader in evaluation |
||||
:param model: The neural network model |
||||
:param criterion: Criterion for calculating loss |
||||
:param optimizer: Optimizer for updating the parameters |
||||
:param lr_scheduler: Learning rate scheduler adjusting the learning rate during training or evaluation |
||||
:param schedule: Running schedule in :meth:`step` |
||||
:type train_dataloader: DataLoader, optional |
||||
:type test_dataloader: DataLoader, optional |
||||
:type model: Module |
||||
:type criterion: _Loss, optional |
||||
:type optimizer: Optimizer, optional |
||||
:type lr_scheduler: _LRScheduler, optional |
||||
:type schedule: BaseSchedule, optional |
||||
""" |
||||
def __init__(self, |
||||
train_dataloader: Optional[DataLoader] = None, |
||||
test_dataloader: Optional[DataLoader] = None, |
||||
model: Module = None, |
||||
criterion: _Loss = None, |
||||
optimizer: Optimizer = None, |
||||
lr_scheduler: Optional[_LRScheduler] = None, |
||||
schedule: BaseSchedule = None): |
||||
self.train_dataloader = train_dataloader |
||||
self.test_dataloader = test_dataloader |
||||
assert model is not None, "Engine requires a model" |
||||
self.model = model |
||||
self.criterion = criterion |
||||
self.optimizer = optimizer |
||||
self.lr_scheduler = lr_scheduler |
||||
self.schedule = schedule if schedule is not None \ |
||||
else NoPipelineSchedule() |
||||
self._logger = get_global_dist_logger() |
||||
|
||||
# build gradient handler |
||||
self._gradient_handlers = [] |
||||
gradient_handler_cfg = [] |
||||
|
||||
if hasattr(gpc.config, 'gradient_handler'): |
||||
assert isinstance(gpc.config.gradient_handler, list), \ |
||||
f'expected gpc.config.gradient_handler to be a list, ' \ |
||||
f'but got type {type(gpc.config.gradient_handler)}' |
||||
gradient_handler_cfg = gpc.config.gradient_handler |
||||
elif isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2, |
||||
ZeroRedundancyOptimizer_Level_3)): |
||||
gradient_handler_cfg = [dict(type='ZeROGradientHandler')] |
||||
self._logger.info( |
||||
"Training with zero is detected, ZeROGradientHandler is automatically " |
||||
"added even though not specified in the configuration", |
||||
ranks=[0]) |
||||
elif gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size( |
||||
ParallelMode.DATA) > 1: |
||||
gradient_handler_cfg = [dict(type='DataParallelGradientHandler')] |
||||
self._logger.info( |
||||
"Data parallel training is detected, DataParallelGradientHandler is automatically " |
||||
"added even though not specified in the configuration", |
||||
ranks=[0]) |
||||
if len(gradient_handler_cfg) == 0: |
||||
self._logger.warning( |
||||
"No gradient handler is set up, please make sure you do not need " |
||||
"to all-reduce the gradients after a training step.", |
||||
ranks=[0]) |
||||
for cfg in gradient_handler_cfg: |
||||
handler = build_gradient_handler(cfg, self.model, self.optimizer) |
||||
self._gradient_handlers.append(handler) |
||||
|
||||
self.schedule.initialize(self.train_dataloader, self.model, |
||||
self.criterion, self.optimizer, |
||||
self.lr_scheduler) |
||||
self.forward_only = False |
||||
|
||||
def handle_gradient(self): |
||||
"""Handles all-reduce operations of gradients across different parallel groups. |
||||
""" |
||||
for handler in self._gradient_handlers: |
||||
handler.handle_gradient() |
||||
|
||||
def set_dataloader(self, data: DataLoader, train: bool = True): |
||||
"""Sets dataloader in training or evaluation. |
||||
|
||||
:param data: Dataloader to be set |
||||
:param train: Set training dataloader if True, otherwise evaluation dataloader |
||||
:type data: DataLoader |
||||
:type train: bool |
||||
""" |
||||
if train: |
||||
self.train_dataloader = data |
||||
else: |
||||
self.test_dataloader = data |
||||
|
||||
def get_model(self): |
||||
"""Returns the neural network model in the engine. |
||||
""" |
||||
return self.model |
||||
def get_optimizer(self): |
||||
"""Returns optimizier in the engine. |
||||
""" |
||||
return self.optimizer |
||||
|
||||
def get_lr_scheduler(self): |
||||
"""Returns the learning rate scheduler in the engine. |
||||
""" |
||||
return self.lr_scheduler |
||||
|
||||
def train(self): |
||||
"""Sets the model to training mode. |
||||
""" |
||||
self.forward_only = False |
||||
self.schedule.train(dataloader=self.train_dataloader, mode=True) |
||||
|
||||
def eval(self): |
||||
"""Sets the model to evaluation mode. |
||||
""" |
||||
self.forward_only = True |
||||
self.schedule.train(dataloader=self.test_dataloader, mode=False) |
||||
|
||||
def is_train(self): |
||||
"""Returns True if it is in training, otherwise False. |
||||
""" |
||||
return not self.forward_only |
||||
|
||||
def get_lr(self): |
||||
"""Gets current learning rate. |
||||
""" |
||||
return self.schedule.get_lr() |
||||
|
||||
def step(self, return_loss=True): |
||||
"""A running step based on the schedule. Usually, it runs a training or |
||||
evaluation step over a batch of data. |
||||
|
||||
:param return_loss: loss will be returned if True |
||||
:type return_loss: bool |
||||
:return: (output, label, loss) |
||||
""" |
||||
self.schedule.zero_grad(forward_only=self.forward_only) |
||||
|
||||
output, label, loss = self.schedule.forward_backward_step( |
||||
forward_only=self.forward_only, return_loss=return_loss) |
||||
|
||||
if not self.forward_only: |
||||
# all reduce gradients |
||||
self.handle_gradient() |
||||
|
||||
self.schedule.step() |
||||
|
||||
return output, label, loss |
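A typical loop built on this class looks like the sketch below. It is illustrative only and assumes the distributed context, dataloaders, model, criterion and optimizer already exist (for example produced by the `initialize` function that appears later in this diff); the loop itself only calls the `Engine` methods defined above.

```python
# Usage sketch; `model`, `criterion`, `optimizer`, `train_dataloader` and
# `test_dataloader` are assumed to have been built already, and the global
# distributed context is assumed to be initialized.
from colossalai.engine import Engine

engine = Engine(train_dataloader=train_dataloader,
                test_dataloader=test_dataloader,
                model=model,
                criterion=criterion,
                optimizer=optimizer)

num_epochs = 10  # illustrative
for epoch in range(num_epochs):
    engine.train()                                 # training mode
    for _ in range(engine.schedule.num_steps):
        output, label, loss = engine.step()

    engine.eval()                                  # evaluation mode
    for _ in range(engine.schedule.num_steps):
        output, label, loss = engine.step(return_loss=True)
```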
@ -0,0 +1,10 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
from enum import Enum |
||||
|
||||
|
||||
class AMP_TYPE(Enum): |
||||
APEX = 'apex' |
||||
TORCH = 'torch' |
||||
PARALLEL = 'parallel' |
@ -0,0 +1,5 @@
|
||||
from ._base_gradient_handler import BaseGradientHandler |
||||
from ._data_parallel_gradient_handler import DataParallelGradientHandler |
||||
from ._zero_gradient_handler import ZeROGradientHandler |
||||
|
||||
__all__ = ['BaseGradientHandler', 'DataParallelGradientHandler', 'ZeROGradientHandler'] |
@ -0,0 +1,25 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
from abc import ABC, abstractmethod |
||||
|
||||
|
||||
class BaseGradientHandler(ABC): |
||||
"""A basic helper class to handle all-reduce operations of gradients across different parallel groups |
||||
before optimization. |
||||
|
||||
:param model: Model where the gradients accumulate |
||||
:param optimizer: Optimizer for updating the parameters |
||||
:type model: Module |
||||
:type optimizer: Optimizer |
||||
""" |
||||
def __init__(self, model, optimizer): |
||||
self._model = model |
||||
self._optimizer = optimizer |
||||
|
||||
@abstractmethod |
||||
def handle_gradient(self): |
||||
"""A method to accumulate gradients across different parallel groups. Users should |
||||
write their own functions or just use the functions in pre-defined subclasses. |
||||
""" |
||||
pass |
@ -0,0 +1,48 @@
|
||||
#!/usr/bin/env python |
||||
|
||||
import torch.distributed as dist |
||||
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors |
||||
|
||||
from colossalai.core import global_context as gpc |
||||
from colossalai.registry import GRADIENT_HANDLER |
||||
from ._base_gradient_handler import BaseGradientHandler |
||||
from ...context.parallel_mode import ParallelMode |
||||
|
||||
|
||||
@GRADIENT_HANDLER.register_module |
||||
class DataParallelGradientHandler(BaseGradientHandler): |
||||
"""A helper class to handle all-reduce operations in a data parallel group. |
||||
An all-reduce collective communication will be executed in |
||||
:func:`handle_gradient` among a data parallel group. |
||||
For better performance, it bucketizes the gradients of all parameters that are |
||||
the same type to improve the efficiency of communication. |
||||
""" |
||||
|
||||
def handle_gradient(self): |
||||
"""A method running a all-reduce operation in a data parallel group. |
||||
""" |
||||
# TODO: add memory buffer |
||||
if gpc.data_parallel_size > 1: |
||||
# bucketize and all-reduce |
||||
buckets = {} |
||||
# Pack the buckets. |
||||
for param in self._model.parameters(): |
||||
if param.requires_grad and param.grad is not None: |
||||
tp = param.data.type() |
||||
if tp not in buckets: |
||||
buckets[tp] = [] |
||||
buckets[tp].append(param) |
||||
param.main_grad = param.grad |
||||
|
||||
# For each bucket, all-reduce and copy all-reduced grads. |
||||
for tp in buckets: |
||||
bucket = buckets[tp] |
||||
grads = [param.grad.data for param in bucket] |
||||
coalesced = _flatten_dense_tensors(grads) |
||||
coalesced /= gpc.get_world_size(ParallelMode.DATA) |
||||
|
||||
dist.all_reduce( |
||||
coalesced, group=gpc.get_group(ParallelMode.DATA)) |
||||
for buf, synced in zip(grads, _unflatten_dense_tensors( |
||||
coalesced, grads)): |
||||
buf.copy_(synced) |
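The bucketing pattern used above can be exercised on a single process. The sketch below is illustrative: it flattens same-dtype gradients into one contiguous buffer, averages them, and copies the results back; in the real handler the in-place division is paired with `dist.all_reduce` over the data parallel group.

```python
# Single-process illustration of the flatten / reduce / unflatten bucketing.
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

# Pretend these are gradients of two parameters sharing one dtype bucket.
grads = [torch.ones(3), torch.ones(2, 2) * 2.0]

coalesced = _flatten_dense_tensors(grads)   # one contiguous buffer
data_parallel_world_size = 4                # hypothetical group size
coalesced /= data_parallel_world_size       # stands in for the all-reduce average

for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
    buf.copy_(synced)

print(grads[0])   # tensor([0.2500, 0.2500, 0.2500])
print(grads[1])   # tensor([[0.5000, 0.5000],
                  #         [0.5000, 0.5000]])
```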
@ -0,0 +1,16 @@
|
||||
from colossalai.registry import GRADIENT_HANDLER |
||||
from ._base_gradient_handler import BaseGradientHandler |
||||
|
||||
|
||||
@GRADIENT_HANDLER.register_module |
||||
class ZeROGradientHandler(BaseGradientHandler): |
||||
"""A helper class to handle all-reduce operations in a data parallel group. |
||||
An all-reduce collective communication will be executed in |
||||
:func:`handle_gradient` among a data parallel group. |
||||
This class is specialized with ZeRO optimization. |
||||
""" |
||||
|
||||
def handle_gradient(self): |
||||
"""A method running a all-reduce operation in a data parallel group. |
||||
""" |
||||
self._optimizer.allreduce_gradients() |
@ -0,0 +1,5 @@
|
||||
from ._base_schedule import BaseSchedule |
||||
from ._no_pipeline import NoPipelineSchedule |
||||
from ._pipeline import PipelineSchedule |
||||
|
||||
__all__ = ['BaseSchedule', 'NoPipelineSchedule', 'PipelineSchedule'] |
@ -0,0 +1,129 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
from abc import ABC, abstractmethod |
||||
|
||||
import torch |
||||
|
||||
from colossalai.logging import get_global_dist_logger |
||||
from colossalai.utils import get_current_device |
||||
|
||||
|
||||
class BaseSchedule(ABC): |
||||
"""A basic helper class to control the process of training or evaluation. |
||||
""" |
||||
def __init__(self): |
||||
self.initialized = False |
||||
self.logger = get_global_dist_logger() |
||||
|
||||
@property |
||||
@abstractmethod |
||||
def num_steps(self): |
||||
"""The number of batches in training or evaluation. |
||||
""" |
||||
pass |
||||
|
||||
def initialize(self, |
||||
dataloader=None, |
||||
model=None, |
||||
criterion=None, |
||||
optimizer=None, |
||||
lr_scheduler=None): |
||||
"""Initializes the schedule and set parameters before running. |
||||
|
||||
:param dataloader: DataLoader in training or evaluation |
||||
:param model: The neural network model |
||||
:param criterion: Criterion for calculating loss |
||||
:param optimizer: Optimizer for updating the parameters |
||||
:param lr_scheduler: Learning rate scheduler in the process |
||||
""" |
||||
self.dataloader = dataloader |
||||
assert model is not None, "Schedule requires a model" |
||||
self.model = model |
||||
assert criterion is not None, "Schedule requires a criterion" |
||||
self.criterion = criterion |
||||
assert optimizer is not None, "Schedule requires an optimizer" |
||||
self.optimizer = optimizer |
||||
self.lr_scheduler = lr_scheduler |
||||
self.initialized = True |
||||
|
||||
def check_initialized(self): |
||||
"""Checks whether the schedule is initialized. |
||||
""" |
||||
assert self.initialized, \ |
||||
'Schedule is not initialized. Call schedule.initialize(...) before using it.' |
||||
|
||||
def load_batch(self): |
||||
"""Loads a batch of dataset. It returns the data and labels which are |
||||
already on the same GPU as the model. |
||||
|
||||
:return: (data, label) |
||||
:rtype: (Tensor, Tensor) |
||||
""" |
||||
self.check_initialized() |
||||
if self.data_iter is None: |
||||
raise RuntimeError('Dataloader is not defined.') |
||||
data, label = next(self.data_iter) |
||||
return self._move_to_device(data), self._move_to_device(label) |
||||
|
||||
def _move_to_device(self, data): |
||||
if isinstance(data, ( |
||||
tuple, |
||||
list, |
||||
)): |
||||
data = tuple([ |
||||
d.to(get_current_device()).detach() for d in data |
||||
if torch.is_tensor(d) |
||||
]) |
||||
elif torch.is_tensor(data): |
||||
data = data.to(get_current_device()).detach() |
||||
return data |
||||
|
||||
def train(self, dataloader=None, mode=True): |
||||
"""Sets the dataloader to be used and turn the model to |
||||
training or evaluation mode. |
||||
|
||||
:param dataloader: Dataloader to be used |
||||
:param mode: If True, the model is set to training mode; otherwise, evaluation mode. |
||||
""" |
||||
self.check_initialized() |
||||
if mode: |
||||
self.model.train() |
||||
else: |
||||
self.model.eval() |
||||
if dataloader is not None: |
||||
self.dataloader = dataloader |
||||
self.data_iter = iter(dataloader) |
||||
|
||||
def zero_grad(self, forward_only=False): |
||||
"""Cleans gradients with the optimizer. |
||||
""" |
||||
if not forward_only: |
||||
self.check_initialized() |
||||
self.optimizer.zero_grad() |
||||
|
||||
def get_lr(self): |
||||
"""Returns the current learning rate. |
||||
""" |
||||
if self.lr_scheduler is not None: |
||||
return self.lr_scheduler.get_lr()[0] |
||||
else: |
||||
return self.optimizer.param_groups[0]['lr'] |
||||
|
||||
def step(self): |
||||
"""Updates the parameters and learning rate with the optimizer. |
||||
""" |
||||
self.check_initialized() |
||||
self.optimizer.step() |
||||
# update lr scheduler |
||||
if self.lr_scheduler is not None: |
||||
self.lr_scheduler.step() |
||||
|
||||
@abstractmethod |
||||
def forward_backward_step(self, forward_only=False, return_loss=True): |
||||
"""The process function over a batch of dataset for training or evaluation. |
||||
|
||||
:param forward_only: If True, the process won't include backward. |
||||
:param return_loss: If False, the loss won't be returned. |
||||
""" |
||||
pass |
@ -0,0 +1,185 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
try: |
||||
import apex.amp as apex_amp |
||||
except ImportError: |
||||
print('apex is required for mixed precision training') |
||||
try: |
||||
import torch.cuda.amp as torch_amp |
||||
except ImportError: |
||||
print('PyTorch amp is not supported with the current PyTorch version') |
||||
|
||||
from colossalai.context import ParallelMode |
||||
from colossalai.core import global_context as gpc |
||||
from colossalai.engine.amp_type import AMP_TYPE |
||||
from colossalai.nn import (ZeroRedundancyOptimizer_Level_2, |
||||
ZeroRedundancyOptimizer_Level_3) |
||||
from ._utils import convert_to_fp16 |
||||
from ._base_schedule import BaseSchedule |
||||
|
||||
|
||||
class NoPipelineSchedule(BaseSchedule): |
||||
"""A helper schedule class for no pipeline parallelism running environment. |
||||
During one process, it loads a batch of dataset and feeds it to the model. |
||||
After getting the output and calculating the loss, it will use :meth:`step` |
||||
to update the parameters if it is in training mode. |
||||
|
||||
:param amp_type: The type of automatic mixed precision |
||||
:param amp_config: The configuration of automatic mixed precision |
||||
:type amp_type: AMP_TYPE |
||||
:type amp_config: dict |
||||
""" |
||||
def __init__( |
||||
self, |
||||
amp_type: AMP_TYPE = None, |
||||
amp_config: dict = None, |
||||
): |
||||
super().__init__() |
||||
|
||||
# mixed precision training |
||||
assert amp_type is None or isinstance(amp_type, AMP_TYPE), \ |
||||
'unrecognised value for argument amp_type, it can only be None or a member of AMP_TYPE' |
||||
|
||||
# LSG: check compatibility |
||||
# LSG: torch.cuda.amp and apex.amp cannot be used for tensor parallel |
||||
if gpc.is_initialized(ParallelMode.TENSOR) and gpc.get_world_size( |
||||
ParallelMode.TENSOR) > 1: |
||||
assert amp_type != AMP_TYPE.TORCH and amp_type != AMP_TYPE.APEX, \ |
||||
'You can only use AMP_TYPE.PARALLEL for tensor parallel training' |
||||
self.use_zero_level_2_3 = False |
||||
|
||||
if amp_type is not None: |
||||
self.fp16 = True |
||||
self.amp_type = amp_type |
||||
|
||||
if amp_config is not None: |
||||
assert isinstance(amp_config, dict), \ |
||||
f'expected argument amp_config to be type dictionary, but got {type(amp_config)}' |
||||
|
||||
if self.amp_type == AMP_TYPE.TORCH: |
||||
# torch amp |
||||
if amp_config is None: |
||||
amp_config = dict() |
||||
self.amp_cfg = amp_config |
||||
elif self.amp_type == AMP_TYPE.APEX: |
||||
# apex amp |
||||
if amp_config is None: |
||||
amp_config = dict(opt_level='O2') |
||||
self.logger.warning( |
||||
'apex is deprecated, please consider using torch.cuda.amp instead.' |
||||
) |
||||
self.amp_cfg = amp_config |
||||
elif self.amp_type == AMP_TYPE.PARALLEL: |
||||
# use fp16 optimizer for tensor parallel training |
||||
if amp_config is None: |
||||
amp_config = dict() |
||||
self.amp_cfg = amp_config |
||||
else: |
||||
self.fp16 = False |
||||
self.amp_type = None |
||||
|
||||
@property |
||||
def num_steps(self): |
||||
return len(self.dataloader) |
||||
|
||||
def initialize(self, |
||||
dataloader, |
||||
model, |
||||
criterion, |
||||
optimizer, |
||||
lr_scheduler=None): |
||||
super().initialize(dataloader, |
||||
model, |
||||
criterion, |
||||
optimizer, |
||||
lr_scheduler=lr_scheduler) |
||||
if isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2, |
||||
ZeroRedundancyOptimizer_Level_3)): |
||||
self.use_zero_level_2_3 = True |
||||
assert self.amp_type != AMP_TYPE.PARALLEL, 'ZeRO Level 2 and 3 are mutually exclusive with AMP_TYPE.PARALLEL' |
||||
|
||||
if self.fp16: |
||||
if self.amp_type == AMP_TYPE.TORCH: |
||||
self._torch_amp_scaler = torch_amp.GradScaler(**self.amp_cfg) |
||||
elif self.amp_type == AMP_TYPE.APEX: |
||||
self.model, self.optimizer = apex_amp.initialize( |
||||
self.model, self.optimizer, **self.amp_cfg) |
||||
|
||||
def forward_backward_step(self, forward_only=False, return_loss=True): |
||||
"""The process function that loads loads a batch of dataset and feeds it to the model. |
||||
The returned labels and loss will be None if :attr:`return_loss` is False. |
||||
|
||||
:return: (output, label, loss) |
||||
""" |
||||
assert forward_only or return_loss, \ |
||||
'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.' |
||||
|
||||
data, label = self.load_batch() |
||||
loss = None |
||||
|
||||
# LSG: leave for debug, make sure dataloader is deterministic |
||||
# if forward_only: |
||||
# img = data[0] |
||||
# rank = gpc.get_local_rank(ParallelMode.DATA) |
||||
# world_size = gpc.get_world_size(ParallelMode.DATA) |
||||
# group = gpc.get_group(ParallelMode.DATA) |
||||
# input_list = [img.clone() for _ in range(world_size)] |
||||
# output_list = [torch.empty_like(img) for _ in range(world_size)] |
||||
# output_list[rank] = img.clone() |
||||
# dist.all_to_all(output_tensor_list=output_list, input_tensor_list=input_list, group=group) |
||||
# assert torch.equal(output_list[0], output_list[1]) # and torch.equal(output_list[1], output_list[2]) |
||||
|
||||
# forward |
||||
if self.fp16 and self.amp_type == AMP_TYPE.TORCH: |
||||
with torch_amp.autocast(): |
||||
output = self.model(*data) |
||||
if not isinstance(output, (tuple, list)): |
||||
output = (output,) |
||||
if return_loss: |
||||
loss = self.criterion(*output, *label) |
||||
else: |
||||
if self.use_zero_level_2_3 or self.amp_type == AMP_TYPE.PARALLEL: |
||||
data = convert_to_fp16(data) |
||||
|
||||
output = self.model(*data) |
||||
if not isinstance(output, (tuple, list)): |
||||
output = (output,) |
||||
if return_loss: |
||||
loss = self.criterion(*output, *label) |
||||
|
||||
if not forward_only: |
||||
# backward |
||||
if self.use_zero_level_2_3: |
||||
self.optimizer.backward(loss) |
||||
elif self.fp16: |
||||
if self.amp_type == AMP_TYPE.APEX: |
||||
with apex_amp.scale_loss(loss, |
||||
self.optimizer) as scaled_loss: |
||||
scaled_loss.backward() |
||||
elif self.amp_type == AMP_TYPE.TORCH: |
||||
self._torch_amp_scaler.scale(loss).backward() |
||||
elif self.amp_type == AMP_TYPE.PARALLEL: |
||||
loss = self.optimizer.scale_loss(loss) |
||||
loss.backward() |
||||
# scale back to display the original value in logs |
||||
loss.div_(self.optimizer.grad_scaler.scale) |
||||
else: |
||||
loss.backward() |
||||
|
||||
if return_loss: |
||||
return output, label, loss |
||||
else: |
||||
return output, None, None |
||||
|
||||
def step(self): |
||||
# step optimizer |
||||
if self.fp16 and self.amp_type == AMP_TYPE.TORCH: |
||||
self._torch_amp_scaler.step(self.optimizer) |
||||
self._torch_amp_scaler.update() |
||||
else: |
||||
self.optimizer.step() |
||||
|
||||
# update lr scheduler |
||||
if self.lr_scheduler is not None: |
||||
self.lr_scheduler.step() |
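For reference, the `amp_type` consumed here is normally driven by the `fp16` section of the user config, which `initialize` (later in this diff) splits into `mode` and the remaining keyword arguments. A hypothetical config fragment might look like the sketch below; the extra key is only an example of what gets forwarded to `torch.cuda.amp.GradScaler` when `mode` is `AMP_TYPE.TORCH`.

```python
# Hypothetical config fragment (values are illustrative, not authoritative).
from colossalai.engine import AMP_TYPE

fp16 = dict(
    mode=AMP_TYPE.TORCH,    # or AMP_TYPE.APEX / AMP_TYPE.PARALLEL
    init_scale=2.**16,      # forwarded to torch.cuda.amp.GradScaler as amp_config
)
```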
@ -0,0 +1,316 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
from typing import Union |
||||
|
||||
import torch.cuda |
||||
import torch.distributed as dist |
||||
from torch import Tensor |
||||
|
||||
from colossalai.communication import * |
||||
from colossalai.context.parallel_mode import ParallelMode |
||||
from colossalai.core import global_context as gpc |
||||
from colossalai.nn import (ZeroRedundancyOptimizer_Level_2, |
||||
ZeroRedundancyOptimizer_Level_3) |
||||
from colossalai.utils import get_current_device |
||||
from ._base_schedule import BaseSchedule |
||||
from ._utils import convert_to_fp16 |
||||
from ..amp_type import AMP_TYPE |
||||
|
||||
|
||||
def squeeze(x: Union[Tensor, tuple, list]): |
||||
if isinstance(x, (tuple, list)): |
||||
return x[0] |
||||
else: |
||||
return x |
||||
|
||||
|
||||
class PipelineSchedule(BaseSchedule): |
||||
"""A helper schedule class for pipeline parallelism running environment. |
||||
It uses the non-interleaved 1F1B strategy. Other properties are similar to |
||||
:class:`NoPipelineSchedule`. |
||||
|
||||
:param num_microbatches: The number of microbatches |
||||
:param amp_type: The type of automatic mixed precision |
||||
:param amp_config: The configuration of automatic mixed precision |
||||
:type num_microbatches: int |
||||
:type amp_type: AMP_TYPE |
||||
:type amp_config: dict |
||||
""" |
||||
|
||||
def __init__(self, |
||||
num_microbatches, |
||||
amp_type: AMP_TYPE = None, |
||||
amp_config: dict = None): |
||||
super().__init__() |
||||
|
||||
self.num_microbatches = num_microbatches |
||||
self.data_sync = True # close after making sure data is identical |
||||
|
||||
# amp |
||||
# LSGL: amp_config is not used, but leave here for future extension |
||||
self.amp_type = amp_type |
||||
self.amp_config = amp_config |
||||
|
||||
if self.amp_type is not None: |
||||
assert self.amp_type == AMP_TYPE.PARALLEL, 'We only support AMP_TYPE.PARALLEL for pipeline training for now' |
||||
|
||||
def _move_to_device(self, data): |
||||
if isinstance(data, ( |
||||
tuple, |
||||
list, |
||||
)): |
||||
assert len(data) == 1, "Data tuple's length in pipeline should be 1" |
||||
data = data[0] |
||||
assert torch.is_tensor(data), "Data in pipeline should be tensor" |
||||
data = data.to(get_current_device()).detach() |
||||
return data |
||||
|
||||
def _sync_data(self): |
||||
if gpc.is_first_rank(ParallelMode.PIPELINE): |
||||
src_rank = gpc.get_global_rank() |
||||
dist.broadcast( |
||||
tensor=self.batch_data, |
||||
src=src_rank, |
||||
group=gpc.get_group(ParallelMode.PIPELINE_PREV) |
||||
) |
||||
dist.broadcast( |
||||
tensor=self.batch_label, |
||||
src=src_rank, |
||||
group=gpc.get_group(ParallelMode.PIPELINE_PREV) |
||||
) |
||||
if gpc.is_last_rank(ParallelMode.PIPELINE): |
||||
src_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE) |
||||
dist.broadcast( |
||||
tensor=self.batch_data, |
||||
src=src_rank, |
||||
group=gpc.get_group(ParallelMode.PIPELINE_NEXT) |
||||
) |
||||
dist.broadcast( |
||||
tensor=self.batch_label, |
||||
src=src_rank, |
||||
group=gpc.get_group(ParallelMode.PIPELINE_NEXT) |
||||
) |
||||
|
||||
# Pipeline schedule just puts data in memory |
||||
def load_batch(self): |
||||
self.check_initialized() |
||||
if self.data_iter is None: |
||||
raise RuntimeError('Dataloader is not defined.') |
||||
self.batch_pos = 0 |
||||
data, label = next(self.data_iter) |
||||
self.batch_data, self.batch_label = \ |
||||
self._move_to_device(data), self._move_to_device(label) |
||||
batch_size = self.batch_data.shape[0] |
||||
assert batch_size % self.num_microbatches == 0, \ |
||||
"Batch size should divided by the number of microbatches" |
||||
self.microbatch_size = batch_size // self.num_microbatches |
||||
if self.data_sync: |
||||
self._sync_data() |
||||
|
||||
def _get_data_slice(self, tensor): |
||||
return tensor[self.batch_pos: self.batch_pos + self.microbatch_size] |
||||
|
||||
def load_micro_batch(self): |
||||
data = self._get_data_slice(self.batch_data) |
||||
label = self._get_data_slice(self.batch_label) |
||||
self.batch_pos += self.microbatch_size |
||||
return (data,), (label,) |
||||
|
||||
@property |
||||
def num_steps(self): |
||||
return len(self.dataloader) |
||||
|
||||
def initialize(self, |
||||
dataloader, |
||||
model, |
||||
criterion, |
||||
optimizer, |
||||
lr_scheduler=None): |
||||
super().initialize(dataloader, |
||||
model, |
||||
criterion, |
||||
optimizer, |
||||
lr_scheduler=lr_scheduler) |
||||
if isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2, |
||||
ZeroRedundancyOptimizer_Level_3)): |
||||
raise TypeError( |
||||
"Pipeline schedule is currently not compatible with ZeRO Level 2 and Level 3" |
||||
) |
||||
|
||||
# LSG: set default dtype to fp16 for communication |
||||
if self.amp_type == AMP_TYPE.PARALLEL: |
||||
torch.set_default_dtype(torch.half) |
||||
self.logger.info( |
||||
'default tensor dtype is set to torch.half for fp16 training', |
||||
ranks=[0]) |
||||
|
||||
def forward_step(self, input_tensor, return_tensors, return_loss=True): |
||||
"""Forward step for passed-in model. If it is the first stage, the input tensor |
||||
is obtained from data_iterator, otherwise the passed-in input_tensor is used. |
||||
Returns output tensor. This is a helper function and can be ignored by users. |
||||
""" |
||||
|
||||
if input_tensor is None: |
||||
input_tensor, label = self.load_micro_batch() |
||||
if self.amp_type == AMP_TYPE.PARALLEL: |
||||
input_tensor = convert_to_fp16(input_tensor) |
||||
input_tensor = squeeze(input_tensor) |
||||
output_tensor = self.model(input_tensor) |
||||
output_tensor = squeeze(output_tensor) |
||||
|
||||
if gpc.is_last_rank(ParallelMode.PIPELINE): |
||||
if return_loss: |
||||
input_tensor, label = self.load_micro_batch() |
||||
loss_reduced = self.criterion(output_tensor, * |
||||
label) / self.num_microbatches |
||||
return_tensors.append( |
||||
tuple((output_tensor, label[0], loss_reduced))) |
||||
return loss_reduced |
||||
else: |
||||
return_tensors.append(output_tensor) |
||||
return output_tensor |
||||
|
||||
else: |
||||
return output_tensor |
||||
|
||||
def backward_step(self, input_tensor, output_tensor, output_tensor_grad): |
||||
"""Backward step through the passed-in output tensor. If it is the last stage, the |
||||
output_tensor_grad is None, otherwise it is the gradients with respect to stage's output tensor. |
||||
Returns the gradients with respect to the input tensor (None if first stage). |
||||
This is a helper function and can be ignored by users. |
||||
""" |
||||
|
||||
# Retain the grad on the input_tensor. |
||||
if input_tensor is not None: |
||||
input_tensor.retain_grad() |
||||
|
||||
# Backward pass. |
||||
if output_tensor_grad is None and self.amp_type == AMP_TYPE.PARALLEL: |
||||
output_tensor = self.optimizer.scale_loss(output_tensor) |
||||
torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad) |
||||
|
||||
# Collect the grad of the input_tensor. |
||||
input_tensor_grad = None |
||||
if input_tensor is not None: |
||||
input_tensor_grad = input_tensor.grad |
||||
|
||||
return input_tensor_grad |
||||
|
||||
def forward_backward_step(self, forward_only=True, return_loss=True): |
||||
"""Runs non-interleaved 1F1B schedule, with communication between pipeline stages. |
||||
Returns a tuple of (output, label, loss) on the last stage, and a tuple of None values otherwise. |
||||
|
||||
:return: (output, label, loss) |
||||
""" |
||||
|
||||
assert forward_only or return_loss, \ |
||||
'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.' |
||||
|
||||
self.load_batch() |
||||
num_warmup_microbatches = \ |
||||
(gpc.get_world_size(ParallelMode.PIPELINE) - |
||||
gpc.get_local_rank(ParallelMode.PIPELINE) - 1) |
||||
num_warmup_microbatches = min(num_warmup_microbatches, |
||||
self.num_microbatches) |
||||
num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches |
||||
|
||||
# Input, output tensors only need to be saved when doing backward passes |
||||
input_tensors = None |
||||
output_tensors = None |
||||
if not forward_only: |
||||
input_tensors = [] |
||||
output_tensors = [] |
||||
return_tensors = [] |
||||
|
||||
# Used for tensor meta information communication |
||||
ft_shape = None |
||||
bt_shape = None |
||||
fs_checker = True |
||||
|
||||
# Run warmup forward passes. |
||||
for i in range(num_warmup_microbatches): |
||||
if not gpc.is_first_rank(ParallelMode.PIPELINE): |
||||
ft_shape = recv_tensor_meta(ft_shape) |
||||
input_tensor = recv_forward(ft_shape) |
||||
output_tensor = self.forward_step(input_tensor, |
||||
return_tensors, |
||||
return_loss=return_loss) |
||||
if not gpc.is_last_rank(ParallelMode.PIPELINE): |
||||
bt_shape = output_tensor.shape |
||||
fs_checker = send_tensor_meta(output_tensor, fs_checker) |
||||
send_forward(output_tensor) |
||||
|
||||
if not forward_only: |
||||
input_tensors.append(input_tensor) |
||||
output_tensors.append(output_tensor) |
||||
|
||||
# Before running 1F1B, need to receive first forward tensor. |
||||
# If all microbatches are run in warmup / cooldown phase, then no need to |
||||
# receive this tensor here. |
||||
if num_microbatches_remaining > 0: |
||||
if not gpc.is_first_rank(ParallelMode.PIPELINE): |
||||
ft_shape = recv_tensor_meta(ft_shape) |
||||
input_tensor = recv_forward(ft_shape) |
||||
|
||||
# Run 1F1B in steady state. |
||||
for i in range(num_microbatches_remaining): |
||||
last_iteration = (i == (num_microbatches_remaining - 1)) |
||||
|
||||
output_tensor = self.forward_step(input_tensor, |
||||
return_tensors, |
||||
return_loss=return_loss) |
||||
if forward_only: |
||||
send_forward(output_tensor) |
||||
|
||||
if not last_iteration: |
||||
input_tensor = recv_forward(ft_shape) |
||||
|
||||
else: |
||||
output_tensor_grad = send_forward_recv_backward( |
||||
output_tensor, bt_shape) |
||||
|
||||
# Add input_tensor and output_tensor to end of list. |
||||
input_tensors.append(input_tensor) |
||||
output_tensors.append(output_tensor) |
||||
|
||||
# Pop input_tensor and output_tensor from the start of the list for |
||||
# the backward pass. |
||||
input_tensor = input_tensors.pop(0) |
||||
output_tensor = output_tensors.pop(0) |
||||
|
||||
input_tensor_grad = self.backward_step(input_tensor, |
||||
output_tensor, |
||||
output_tensor_grad) |
||||
|
||||
if last_iteration: |
||||
input_tensor = None |
||||
send_backward(input_tensor_grad) |
||||
else: |
||||
input_tensor = send_backward_recv_forward( |
||||
input_tensor_grad, ft_shape) |
||||
|
||||
# Run cooldown backward passes. |
||||
if not forward_only: |
||||
for i in range(num_warmup_microbatches): |
||||
input_tensor = input_tensors.pop(0) |
||||
output_tensor = output_tensors.pop(0) |
||||
|
||||
output_tensor_grad = recv_backward(bt_shape) |
||||
|
||||
input_tensor_grad = self.backward_step(input_tensor, |
||||
output_tensor, |
||||
output_tensor_grad) |
||||
|
||||
send_backward(input_tensor_grad) |
||||
|
||||
if len(return_tensors) > 0: |
||||
if return_loss: |
||||
output, label, loss = tuple(map(list, zip(*return_tensors))) |
||||
return (torch.cat(output, dim=0), |
||||
torch.cat(label, dim=0), |
||||
sum(loss)) |
||||
else: |
||||
return tuple((torch.cat(return_tensors, dim=0), None, None)) |
||||
else: |
||||
return tuple((None, None, None)) |
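The warmup bookkeeping above is easy to check by hand. The standalone sketch below (with illustrative sizes) computes, for every pipeline stage, how many warmup forward passes run before the 1F1B steady state; later stages warm up with fewer microbatches because they sit deeper in the pipeline.

```python
# Pure-Python check of the 1F1B warmup arithmetic used above (illustrative sizes).
pipeline_world_size = 4   # hypothetical number of pipeline stages
num_microbatches = 8      # hypothetical number of microbatches per batch

for stage in range(pipeline_world_size):
    num_warmup = min(pipeline_world_size - stage - 1, num_microbatches)
    num_steady = num_microbatches - num_warmup
    print(f'stage {stage}: warmup forwards={num_warmup}, 1F1B steady steps={num_steady}')
# stage 0: warmup forwards=3, 1F1B steady steps=5
# stage 1: warmup forwards=2, 1F1B steady steps=6
# stage 2: warmup forwards=1, 1F1B steady steps=7
# stage 3: warmup forwards=0, 1F1B steady steps=8
```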
@ -0,0 +1,16 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
from typing import Union, List |
||||
|
||||
from torch import Tensor |
||||
|
||||
|
||||
def convert_to_fp16(data: Union[Tensor, List[Tensor]]): |
||||
if isinstance(data, Tensor): |
||||
ret = data.half() |
||||
elif isinstance(data, (list, tuple)): |
||||
ret = [val.half() for val in data] |
||||
else: |
||||
raise TypeError(f"Expected argument 'data' to be a Tensor or a list/tuple of Tensor, but got {type(data)}") |
||||
return ret |
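A quick usage example on CPU tensors (the import path below follows the file layout in this diff and is an assumption):

```python
import torch
# Assumed module path; adjust to wherever this helper lives in your checkout.
from colossalai.engine.schedule._utils import convert_to_fp16

x = torch.randn(2, 2)
xs = [torch.randn(3), torch.randn(4)]

print(convert_to_fp16(x).dtype)                  # torch.float16
print([t.dtype for t in convert_to_fp16(xs)])    # [torch.float16, torch.float16]
```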
@ -0,0 +1,371 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
import argparse |
||||
import pprint |
||||
import random |
||||
from pathlib import Path |
||||
from typing import Callable, Iterable, Optional, Union |
||||
|
||||
import numpy as np |
||||
import torch |
||||
from torch.utils.data import DataLoader |
||||
|
||||
from colossalai.engine import AMP_TYPE, NoPipelineSchedule, PipelineSchedule |
||||
from colossalai.logging import get_global_dist_logger, init_global_dist_logger |
||||
from colossalai.nn import DataParallelSampler |
||||
from colossalai.nn.model.base_model import BaseModel |
||||
from .builder import (ModelInitializer, build_dataset, build_loss, |
||||
build_lr_scheduler, build_model, build_optimizer, |
||||
build_optimizer_wrapper) |
||||
from .context import Config, ParallelMode |
||||
from .core import global_context as gpc |
||||
from .utils import get_current_device, sync_model_param_in_dp |
||||
|
||||
|
||||
def parse_args(): |
||||
'''Reads user command line and uses an argument parser to parse the input arguments. |
||||
Input arguments include configuration, host, port, world size, local rank, backend for torch.distributed. |
||||
|
||||
:return: the parsed input arguments |
||||
:rtype: Namespace |
||||
''' |
||||
parser = argparse.ArgumentParser() |
||||
parser.add_argument('--config', type=str, help='path to the config file') |
||||
parser.add_argument('--host', |
||||
type=str, |
||||
default=None, |
||||
help='the master address for distributed training') |
||||
parser.add_argument('--port', |
||||
type=str, |
||||
default=None, |
||||
help='the master port for distributed training') |
||||
parser.add_argument('--world_size', type=int, help='world size for distributed training') |
||||
parser.add_argument('--local_rank', |
||||
type=int, |
||||
help='rank for the default process group') |
||||
parser.add_argument('--backend', |
||||
type=str, |
||||
default='nccl', |
||||
help='backend for torch.distributed') |
||||
return parser.parse_args() |
||||
|
||||
|
||||
def init_dist(config: Union[str, dict] = None, |
||||
local_rank: int = None, |
||||
world_size: int = None, |
||||
host: str = None, |
||||
port: str = None, |
||||
backend: str = None): |
||||
'''This function first parses the configuration arguments, using :func:`parse_args` when any of the input arguments is not given. |
||||
It then initializes and sets up the distributed environment by calling the global context's functions. |
||||
|
||||
:param config: config file or config file path are both acceptable |
||||
:type config: Union[str, dict], optional |
||||
:param local_rank: rank for the default process group, defaults to None |
||||
:type local_rank: int, optional |
||||
:param world_size: world size of GPUs, defaults to None |
||||
:type world_size: int, optional |
||||
:param host: the master address for distributed training, defaults to None |
||||
:type host: str, optional |
||||
:param port: the master port for distributed training, defaults to None |
||||
:type port: str, optional |
||||
:param backend: backend for torch.distributed, defaults to None |
||||
:type backend: str, optional |
||||
:raises Exception: raised when the config type is wrong |
||||
''' |
||||
args = [config, local_rank, world_size, host, port, backend] |
||||
arg_given = [arg is not None for arg in args] |
||||
|
||||
if not all(arg_given): |
||||
args = parse_args() |
||||
|
||||
if config is None: |
||||
config = args.config |
||||
if local_rank is None: |
||||
local_rank = args.local_rank |
||||
if world_size is None: |
||||
world_size = args.world_size |
||||
if host is None: |
||||
host = args.host |
||||
if port is None: |
||||
port = args.port |
||||
if backend is None: |
||||
backend = args.backend |
||||
args = Config( |
||||
dict(config=config, |
||||
host=host, |
||||
port=port, |
||||
world_size=world_size, |
||||
local_rank=local_rank, |
||||
backend=backend)) |
||||
|
||||
# set distributed settings |
||||
dist_args = Config( |
||||
dict(local_rank=args.local_rank, |
||||
world_size=args.world_size, |
||||
backend=args.backend)) |
||||
|
||||
gpc.set_dist_args(dist_args) |
||||
|
||||
# set config |
||||
if isinstance(args.config, dict): |
||||
cfg = args.config |
||||
elif isinstance(args.config, (str, Path)): |
||||
cfg = Config.from_file(args.config) |
||||
else: |
||||
raise Exception('Config type error: {}'.format(type(args.config))) |
||||
gpc.load_config(cfg) |
||||
|
||||
# init dist groups |
||||
gpc.init_global_dist(args.host, args.port) |
||||
gpc.init_parallel_groups() |
||||
|
||||
# init dist logger |
||||
init_global_dist_logger() |
||||
|
||||
# set cuda device |
||||
if torch.cuda.is_available(): |
||||
gpc.set_device() |
||||
|
||||
|
||||
def get_dataloader(dataset, seed=1024, add_sampler_if_possible=False, **kwargs): |
||||
'''Sets up a deterministic dataloader (also configures worker seeds, samplers and whether to shuffle). |
||||
|
||||
.. note:: when pipeline parallel is enabled, shuffle cannot be True |
||||
as it will result in mismatch between input data on the 1st |
||||
stage and label on the last stage |
||||
|
||||
:param dataset: a :class:`torch.utils.data.Dataset` object |
||||
:param seed: random worker seed, defaults to 1024 |
||||
:type seed: int, optional |
||||
:param add_sampler_if_possible: add a :class:`DataParallelSampler` if data parallelism is initialized, defaults to False |
||||
:type add_sampler_if_possible: bool, optional |
||||
:return: a deterministic :class:`torch.utils.data.DataLoader` |
||||
:rtype: torch.utils.data.DataLoader |
||||
''' |
||||
_kwargs = kwargs.copy() |
||||
if 'shuffle' in _kwargs: |
||||
shuffle = _kwargs.pop('shuffle') |
||||
else: |
||||
shuffle = False |
||||
|
||||
if add_sampler_if_possible and gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1: |
||||
sampler = DataParallelSampler(dataset, shuffle=shuffle) |
||||
else: |
||||
sampler = None |
||||
|
||||
# Deterministic dataloader |
||||
def seed_worker(worker_id): |
||||
worker_seed = seed |
||||
np.random.seed(worker_seed) |
||||
torch.manual_seed(worker_seed) |
||||
random.seed(worker_seed) |
||||
|
||||
if sampler is None: |
||||
return DataLoader(dataset, |
||||
worker_init_fn=seed_worker, |
||||
shuffle=shuffle, |
||||
**_kwargs) |
||||
else: |
||||
return DataLoader(dataset, |
||||
sampler=sampler, |
||||
worker_init_fn=seed_worker, |
||||
**_kwargs) |
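The deterministic-worker pattern above can be reproduced without the distributed sampler. The following standalone sketch builds a plain PyTorch dataloader with the same fixed-seed `worker_init_fn`; `DataParallelSampler` is left out because it needs an initialized data parallel group.

```python
# Standalone sketch of the deterministic dataloader setup (no DataParallelSampler).
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

seed = 1024  # same default as get_dataloader

def seed_worker(worker_id):
    # Pin every worker to a fixed seed, as get_dataloader does.
    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)

def main():
    dataset = TensorDataset(torch.arange(16, dtype=torch.float32).view(8, 2))
    loader = DataLoader(dataset,
                        batch_size=4,
                        shuffle=False,
                        num_workers=2,
                        worker_init_fn=seed_worker)
    for (batch,) in loader:
        print(batch.shape)   # torch.Size([4, 2]), printed twice

if __name__ == '__main__':
    main()
```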
||||
|
||||
|
||||
def initialize(config: Union[str, dict] = None, |
||||
local_rank: int = None, |
||||
world_size: int = None, |
||||
host: str = None, |
||||
port: str = None, |
||||
backend: str = None, |
||||
train_dataloader: Optional[Union[Iterable, Callable]] = None, |
||||
test_dataloader: Optional[Union[Iterable, Callable]] = None, |
||||
): |
||||
'''Core function that initializes the distributed environment, logger, cudnn, data, model, loss function, optimizer and lr_scheduler (their configs are taken from gpc.config). |
||||
|
||||
:param config: config file or config file path are both acceptable |
||||
:type config: Union[str, dict], optional |
||||
:param local_rank: rank for the default process group, defaults to None |
||||
:type local_rank: int, optional |
||||
:param world_size: world size of GPUs, defaults to None |
||||
:type world_size: int, optional |
||||
:param host: the master address for distributed training, defaults to None |
||||
:type host: str, optional |
||||
:param port: the master port for distributed training, defaults to None |
||||
:type port: str, optional |
||||
:param backend: backend for torch.distributed, defaults to None |
||||
:type backend: str, optional |
||||
:param train_dataloader: If None, the config is used to build a dataloader; otherwise, it should be a dataloader object or a zero-argument function that builds a dataloader, defaults to None |
||||
:type train_dataloader: Optional[Union[Iterable, Callable]], optional |
||||
:param test_dataloader: If None, the config is used to build a dataloader; otherwise, it should be a dataloader object or a zero-argument function that builds a dataloader, defaults to None |
||||
:type test_dataloader: Optional[Union[Iterable, Callable]], optional |
||||
:return: (model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler) |
||||
:rtype: tuple |
||||
''' |
||||
# initialize distributed environment |
||||
init_dist(config=config, |
||||
local_rank=local_rank, |
||||
world_size=world_size, |
||||
host=host, |
||||
port=port, |
||||
backend=backend) |
||||
|
||||
# init logger |
||||
logger = get_global_dist_logger() |
||||
logger.info(f'Distributed environment is initialized, ' |
||||
f'data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, ' |
||||
f'tensor parallel size: {gpc.tensor_parallel_size}', ranks=[0]) |
||||
|
||||
# print config |
||||
logger.info(f"\n========== Your Config ========\n" |
||||
f"{pprint.pformat(gpc.config)}\n" |
||||
f"================================", ranks=[0]) |
||||
|
||||
# cudnn |
||||
cudnn_benchmark = gpc.config.get('cudnn_benchmark', True) |
||||
cudnn_deterministic = gpc.config.get('cudnn_deterministic', False) |
||||
torch.backends.cudnn.benchmark = cudnn_benchmark |
||||
torch.backends.cudnn.deterministic = cudnn_deterministic |
||||
logger.info( |
||||
f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0]) |
||||
|
||||
# set seed, cuda seed is only set when cuda is avail |
||||
gpc.set_seed() |
||||
|
||||
# return_items = list() |
||||
|
||||
# check fp16 and zero |
||||
should_convert_model_to_half = False |
||||
should_wrap_fp16_optimizer = False |
||||
should_wrap_zero_optimizer_level_2_3 = False |
||||
|
||||
if hasattr(gpc.config, 'fp16'): |
||||
fp16_mode = gpc.config.fp16.mode |
||||
if fp16_mode == AMP_TYPE.PARALLEL: |
||||
should_convert_model_to_half = True |
||||
should_wrap_fp16_optimizer = True |
||||
|
||||
if hasattr(gpc.config, 'zero'): |
||||
should_wrap_zero_optimizer_level_2_3 = True |
||||
zero_type = gpc.config.zero.type |
||||
if zero_type in ['ZeroRedundancyOptimizer_Level_2', 'ZeroRedundancyOptimizer_Level_3']: |
||||
should_convert_model_to_half = True |
||||
assert not should_wrap_fp16_optimizer, \ |
||||
'AMP_TYPE.PARALLEL is mutually exclusive with zero level 2 and 3' |
||||
|
||||
# build model |
||||
logger.info('Building model ...', ranks=[0]) |
||||
assert hasattr( |
||||
gpc.config, 'model'), "Build error: configuration 'model' is missing" |
||||
if gpc.pipeline_parallel_size > 1: |
||||
model = ModelInitializer(gpc.config.model, 1, verbose=True) |
||||
model = model.model_initialize() |
||||
else: |
||||
model = build_model(gpc.config.model) |
||||
if isinstance(model, BaseModel): |
||||
model.build_from_cfg() |
||||
model = model.to(get_current_device()) |
||||
sync_model_param_in_dp(model) |
||||
logger.info('Model is created', ranks=[0]) |
||||
|
||||
if should_convert_model_to_half: |
||||
model = model.half() |
||||
logger.info("Model is cast to fp16", ranks=[0]) |
||||
|
||||
# training data |
||||
if callable(train_dataloader): |
||||
logger.info( |
||||
f'Build train data loader from {train_dataloader}', ranks=[0]) |
||||
train_dataloader = train_dataloader() |
||||
if train_dataloader is None and hasattr(gpc.config, 'train_data'): |
||||
logger.info('Preparing data ...', ranks=[0]) |
||||
# assert hasattr(gpc.config, 'train_data'), "Build error: configuration 'train_data' is missing." |
||||
train_dataset = build_dataset(gpc.config.train_data.dataset) |
||||
logger.info('Train dataset is ready.', ranks=[0]) |
||||
|
||||
train_dataloader = get_dataloader(train_dataset, |
||||
gpc.config.get('seed', 1024), |
||||
True, |
||||
**gpc.config.train_data.dataloader, |
||||
) |
||||
logger.info( |
||||
f'Loaded {len(train_dataset)} samples in {len(train_dataloader)} batches for training', ranks=[0]) |
||||
|
||||
if callable(test_dataloader): |
||||
logger.info( |
||||
f'Build test data loader from {test_dataloader}', ranks=[0]) |
||||
test_dataloader = test_dataloader() |
||||
# testing data, allowed to be None |
||||
if test_dataloader is None and hasattr(gpc.config, 'test_data'): |
||||
test_dataset = build_dataset(gpc.config.test_data.dataset) |
||||
test_dataloader = get_dataloader( |
||||
test_dataset, add_sampler_if_possible=True, **gpc.config.test_data.dataloader) |
||||
logger.info( |
||||
f'Loaded {len(test_dataset)} samples in {len(test_dataloader)} batches for testing', ranks=[0]) |
||||
|
||||
# build loss function |
||||
assert hasattr(gpc.config, 'loss'), \ |
||||
'Build error: configuration \'loss\' is missing.' |
||||
criterion = build_loss(gpc.config.loss) |
||||
logger.info('Loss function is created', ranks=[0]) |
||||
|
||||
# build optimizer |
||||
assert hasattr(gpc.config, 'optimizer'), \ |
||||
"Build error: configuration 'optimizer' is missing." |
||||
optim_type = gpc.config.optimizer.type |
||||
is_pytorch_native_zero_level_1 = optim_type == 'ZeroRedundancyOptimizer' |
||||
if is_pytorch_native_zero_level_1: |
||||
original_cfg_copy = gpc.config.optimizer.copy() |
||||
original_cfg_copy.pop('type') |
||||
cfg = dict(type=optim_type, process_group=gpc.get_group( |
||||
ParallelMode.DATA), **original_cfg_copy) |
||||
optimizer = build_optimizer(cfg, model) |
||||
else: |
||||
optimizer = build_optimizer(gpc.config.optimizer, model) |
||||
|
||||
if should_wrap_zero_optimizer_level_2_3: |
||||
optimizer = build_optimizer_wrapper(gpc.config.zero, optimizer, model) |
||||
|
||||
if should_wrap_fp16_optimizer: |
||||
# replace the field mode with type |
||||
fp16_cfg = gpc.config.fp16.copy() |
||||
amp_type = fp16_cfg.pop('mode') |
||||
assert amp_type == AMP_TYPE.PARALLEL, 'FP16Optimizer should only be used with AMP_TYPE.PARALLEL' |
||||
fp16_cfg['type'] = 'FP16Optimizer' |
||||
optimizer = build_optimizer_wrapper(fp16_cfg, optimizer) |
||||
logger.info('Optimizer is created', ranks=[0]) |
||||
|
||||
lr_scheduler = None |
||||
if hasattr(gpc.config, 'lr_scheduler'): |
||||
if hasattr(gpc.config, 'num_steps'): |
||||
total_steps = gpc.config.num_steps |
||||
elif hasattr(gpc.config, 'num_epochs'): |
||||
total_steps = int(gpc.config.num_epochs * len(train_dataloader)) |
||||
else: |
||||
raise Exception( |
||||
'Please specify training stopping criterion num_steps or num_epochs in your configuration.' |
||||
) |
||||
lr_scheduler = build_lr_scheduler(gpc.config.lr_scheduler, optimizer, |
||||
total_steps, len(train_dataloader)) |
||||
logger.info('Learning rate scheduler is created', ranks=[0]) |
||||
|
||||
# pipeline or no pipeline schedule |
||||
if hasattr(gpc.config, 'fp16'): |
||||
amp_type = gpc.config.fp16.mode |
||||
amp_cfg = gpc.config.fp16.copy() |
||||
amp_cfg.pop('mode') |
||||
else: |
||||
amp_type = None |
||||
amp_cfg = None |
||||
|
||||
if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1: |
||||
assert hasattr(gpc.config, |
||||
'schedule'), "Config 'schedule' not found in your configuration file for pipeline parallel training" |
||||
schedule = PipelineSchedule( |
||||
amp_type=amp_type, amp_config=amp_cfg, **gpc.config.schedule.copy()) |
||||
else: |
||||
schedule = NoPipelineSchedule(amp_type=amp_type, amp_config=amp_cfg) |
||||
|
||||
return model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler |
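# Editorial sketch (not part of the original file) of the config fields this |
# initializer reads via the hasattr checks above; every value below is a |
# hypothetical placeholder, not a documented default: |
# |
#     loss = dict(type='CrossEntropyLoss') |
#     optimizer = dict(type='Adam', lr=1e-3) |
#     lr_scheduler = dict(type='CosineAnnealingLR') |
#     num_epochs = 100                      # or num_steps = ... |
#     fp16 = dict(mode=AMP_TYPE.PARALLEL)   # optional, enables the FP16 wrapper path |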
@ -0,0 +1,26 @@
|
||||
from colossalai.core import global_context as gpc |
||||
from .logging import DistributedLogger |
||||
|
||||
__all__ = ['get_global_dist_logger', 'get_dist_logger', 'DistributedLogger', 'init_global_dist_logger'] |
||||
|
||||
_GLOBAL_LOGGER: DistributedLogger = None |
||||
|
||||
|
||||
def get_dist_logger(name, level='INFO', root_path: str = None, mode='a'): |
||||
return DistributedLogger(name=name, level=level, root_path=root_path, mode=mode) |
||||
|
||||
|
||||
def get_global_dist_logger(): |
||||
assert _GLOBAL_LOGGER is not None, 'Global distributed logger is not initialized' |
||||
return _GLOBAL_LOGGER |
||||
|
||||
|
||||
def init_global_dist_logger(): |
||||
rank = gpc.get_global_rank() |
||||
if hasattr(gpc.config, 'logging'): |
||||
logger = get_dist_logger(name=f'rank_{rank}', **gpc.config.logging) |
||||
else: |
||||
logger = get_dist_logger(name=f'rank_{rank}', level='INFO') |
||||
global _GLOBAL_LOGGER |
||||
assert _GLOBAL_LOGGER is None, 'Global distributed logger has already been initialized' |
||||
_GLOBAL_LOGGER = logger |
@ -0,0 +1,97 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
import logging |
||||
from pathlib import Path |
||||
|
||||
from colossalai.context.parallel_mode import ParallelMode |
||||
from colossalai.core import global_context as gpc |
||||
|
||||
_FORMAT = 'colossalai - %(name)s - %(asctime)s %(levelname)s: %(message)s' |
||||
logging.basicConfig(level=logging.INFO, format=_FORMAT) |
||||
|
||||
|
||||
class DistributedLogger: |
||||
"""This is a distributed event logger class essentially based on :class:`logging`. |
||||
|
||||
:param name: The name of the logger |
||||
:type name: str |
||||
:param level: The threshold for the logger. Logging messages which are less severe than `level` |
||||
will be ignored |
||||
:type level: str |
||||
:param root_path: The root path where logs are stored |
||||
:type root_path: str, optional |
||||
:param mode: The mode that the file is opened in. Defaults to 'a' |
||||
:type mode: str, optional |
||||
""" |
||||
|
||||
def __init__(self, name, level='INFO', root_path: str = None, mode='a'): |
||||
self._logger = logging.getLogger(name) |
||||
self._logger.setLevel(getattr(logging, level)) |
||||
|
||||
if root_path is not None: |
||||
log_root_path = Path(root_path) |
||||
# create path if not exists |
||||
log_root_path.mkdir(parents=True, exist_ok=True) |
||||
log_path = log_root_path.joinpath(f'{name}.log') |
||||
file_handler = logging.FileHandler(log_path, mode) |
||||
file_handler.setLevel(getattr(logging, level)) |
||||
formatter = logging.Formatter(_FORMAT) |
||||
file_handler.setFormatter(formatter) |
||||
self._logger.addHandler(file_handler) |
||||
|
||||
def _log(self, level, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None): |
||||
if ranks is None: |
||||
getattr(self._logger, level)(message) |
||||
else: |
||||
local_rank = gpc.get_local_rank(parallel_mode) |
||||
if local_rank in ranks: |
||||
getattr(self._logger, level)(message) |
||||
|
||||
def info(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None): |
||||
"""Stores an info log message. |
||||
|
||||
:param message: |
||||
:type message: |
||||
:param parallel_mode: |
||||
:type parallel_mode: |
||||
:param ranks: |
||||
:type ranks: |
||||
""" |
||||
self._log('info', message, parallel_mode, ranks) |
||||
|
||||
def warning(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None): |
||||
"""Stores a warning log message. |
||||
|
||||
:param message: The message to be logged |
||||
:type message: str |
||||
:param parallel_mode: The parallel mode used for logging. Defaults to ParallelMode.GLOBAL |
||||
:type parallel_mode: :class:`colossalai.context.parallel_mode.ParallelMode` |
||||
:param ranks: List of parallel ranks |
||||
:type ranks: list |
||||
""" |
||||
self._log('warning', message, parallel_mode, ranks) |
||||
|
||||
def debug(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None): |
||||
"""Stores a debug log message. |
||||
|
||||
:param message: The message to be logged |
||||
:type message: str |
||||
:param parallel_mode: The parallel mode used for logging. Defaults to ParallelMode.GLOBAL |
||||
:type parallel_mode: :class:`colossalai.context.parallel_mode.ParallelMode` |
||||
:param ranks: List of parallel ranks |
||||
:type ranks: list |
||||
""" |
||||
self._log('debug', message, parallel_mode, ranks) |
||||
|
||||
def error(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None): |
||||
"""Stores an error log message. |
||||
|
||||
:param message: The message to be logged |
||||
:type message: str |
||||
:param parallel_mode: The parallel mode used for logging. Defaults to ParallelMode.GLOBAL |
||||
:type parallel_mode: :class:`colossalai.context.parallel_mode.ParallelMode` |
||||
:param ranks: List of parallel ranks |
||||
:type ranks: list |
||||
""" |
||||
self._log('error', message, parallel_mode, ranks) |
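# A minimal usage sketch for the class above (paths and messages are |
# hypothetical; assumes the distributed context has been initialized): |
# |
#     logger = DistributedLogger(name='rank_0', level='INFO', root_path='./logs') |
#     logger.info('training started', ranks=[0])   # printed only by local rank 0 of the GLOBAL group |
#     logger.warning('loss became NaN')            # printed by every rank |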
@ -0,0 +1,6 @@
|
||||
from .data import * |
||||
from .layer import * |
||||
from .loss import * |
||||
from .lr_scheduler import * |
||||
from .model import * |
||||
from .optimizer import * |
@ -0,0 +1,3 @@
|
||||
from .caltech101_dataset import Caltech101Dataset |
||||
from .cifar10_dataset import CIFAR10Dataset |
||||
from .sampler import * |
@ -0,0 +1,14 @@
|
||||
import numpy as np |
||||
|
||||
|
||||
def pil_img_to_numpy(pil_img): |
||||
"""convert a PIL image to numpy nd-array |
||||
|
||||
:param pil_img: a PIL image |
||||
:type pil_img: PIL.Image |
||||
:return: a nd-array |
||||
:rtype: numpy.ndarray |
||||
""" |
||||
np_img = np.array(pil_img) |
||||
np_img = np.rollaxis(np_img, 2) # HWC to CHW |
||||
return np_img |
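# Editorial shape check for the helper above (assumes a 3-channel RGB image; |
# doctest-style, not part of the original file): |
# |
#     >>> from PIL import Image |
#     >>> img = Image.new('RGB', (32, 64))   # width 32, height 64 |
#     >>> pil_img_to_numpy(img).shape        # HWC (64, 32, 3) becomes CHW |
#     (3, 64, 32) |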
@ -0,0 +1,17 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
from abc import ABC |
||||
|
||||
from torch.utils.data import Dataset |
||||
from torchvision.transforms import transforms |
||||
|
||||
from colossalai.builder import build_transform |
||||
|
||||
|
||||
class BaseDataset(Dataset, ABC): |
||||
|
||||
def __init__(self, transform_pipeline: list): |
||||
transform_list = [build_transform(cfg) for cfg in transform_pipeline] |
||||
transform = transforms.Compose(transform_list) |
||||
self._transform_pipeline = transform |
@ -0,0 +1,43 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
import torch.distributed as dist |
||||
from torchvision.datasets import Caltech101 |
||||
|
||||
from colossalai.context import ParallelMode |
||||
from colossalai.core import global_context as gpc |
||||
from colossalai.registry import DATASETS |
||||
from .base_dataset import BaseDataset |
||||
|
||||
|
||||
@DATASETS.register_module |
||||
class Caltech101Dataset(BaseDataset): |
||||
"""`Caltech 101 <http://www.vision.caltech.edu/Image_Datasets/Caltech101/>`_ Dataset. |
||||
|
||||
:param transform_pipeline: A list of transform configs, each of which takes in a PIL image |
and returns a transformed version |
||||
:type transform_pipeline: list |
||||
""" |
||||
|
||||
def __init__(self, transform_pipeline: list, *args, **kwargs): |
||||
super().__init__(transform_pipeline) |
||||
if gpc.is_initialized(ParallelMode.GLOBAL) and gpc.get_global_rank() != 0: |
||||
dist.barrier() |
||||
self._dataset = Caltech101( |
||||
transform=self._transform_pipeline, *args, **kwargs) |
||||
if gpc.is_initialized(ParallelMode.GLOBAL) and gpc.get_global_rank() == 0: |
||||
dist.barrier() |
||||
|
||||
def __len__(self): |
||||
return len(self._dataset) |
||||
|
||||
def __getitem__(self, item): |
||||
""" |
||||
|
||||
:param item: Index |
||||
:type item: int |
||||
:return: ((image,), (target,)) where the type of target is specified by target_type. |
||||
:rtype: tuple |
||||
""" |
||||
img, label = self._dataset.__getitem__(item) |
||||
return (img,), (label,) |
@ -0,0 +1,44 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
import torch.distributed as dist |
||||
from torchvision.datasets import CIFAR10 |
||||
|
||||
from colossalai.context import ParallelMode |
||||
from colossalai.core import global_context as gpc |
||||
from colossalai.registry import DATASETS |
||||
from .base_dataset import BaseDataset |
||||
|
||||
|
||||
@DATASETS.register_module |
||||
class CIFAR10Dataset(BaseDataset): |
||||
"""`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset. |
||||
|
||||
:param transform_pipeline: A list of transform configs, each of which takes in a PIL image |
and returns a transformed version |
||||
:type transform_pipeline: list |
||||
""" |
||||
|
||||
def __init__(self, transform_pipeline: list, *args, **kwargs): |
||||
super().__init__(transform_pipeline) |
||||
if gpc.is_initialized(ParallelMode.GLOBAL) and gpc.get_global_rank() != 0: |
||||
dist.barrier() |
||||
self._dataset = CIFAR10(transform=self._transform_pipeline, |
||||
*args, |
||||
**kwargs) |
||||
if gpc.is_initialized(ParallelMode.GLOBAL) and gpc.get_global_rank() == 0: |
||||
dist.barrier() |
||||
|
||||
def __len__(self): |
||||
return len(self._dataset) |
||||
|
||||
def __getitem__(self, item): |
||||
""" |
||||
|
||||
:param item: Index |
||||
:type item: int |
||||
:return: ((image,), (label,)) where label is the class index of the image. |
||||
:rtype: tuple |
||||
""" |
||||
img, label = self._dataset.__getitem__(item) |
||||
return (img,), (label,) |
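# Hedged usage sketch for the dataset above; 'ToTensor' stands in for whatever |
# transform configs are registered in this codebase, and the remaining keyword |
# arguments are forwarded to torchvision's CIFAR10: |
# |
#     dataset = CIFAR10Dataset(transform_pipeline=[dict(type='ToTensor')], |
#                              root='./data', train=True, download=True) |
#     (img,), (label,) = dataset[0] |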
@ -0,0 +1,4 @@
|
||||
from .base_sampler import BaseSampler |
||||
from .data_parallel_sampler import DataParallelSampler |
||||
|
||||
__all__ = ['BaseSampler', 'DataParallelSampler'] |
@ -0,0 +1,19 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
from abc import ABC, abstractmethod |
||||
|
||||
|
||||
class BaseSampler(ABC): |
||||
|
||||
def __init__(self, dataset, batch_size): |
||||
self.dataset = dataset |
||||
self.batch_size = batch_size |
||||
|
||||
@abstractmethod |
||||
def __len__(self): |
||||
pass |
||||
|
||||
@abstractmethod |
||||
def __iter__(self): |
||||
pass |
@ -0,0 +1,102 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
# adapted from torch.utils.data.DistributedSampler |
||||
|
||||
import math |
||||
from typing import TypeVar, Iterator |
||||
|
||||
import torch |
||||
from torch.utils.data import Sampler, Dataset |
||||
|
||||
from colossalai.context.parallel_mode import ParallelMode |
||||
from colossalai.core import global_context as gpc |
||||
from colossalai.registry import SAMPLERS |
||||
|
||||
T_co = TypeVar('T_co', covariant=True) |
||||
|
||||
|
||||
@SAMPLERS.register_module |
||||
class DataParallelSampler(Sampler): |
||||
"""A data sampler for distributed data parallelism |
||||
|
||||
:param dataset: a Dataset instance |
||||
:type dataset: torch.utils.data.Dataset |
||||
:param shuffle: whether to shuffle data, defaults to False |
||||
:type shuffle: bool, optional |
||||
:param seed: the random seed, defaults to 0 |
||||
:type seed: int, optional |
||||
:param drop_last: set to True to drop the tail of the data so that it is evenly divisible across the number of replicas. If False, extra indices will be added to make the data evenly divisible across the replicas, defaults to False |
||||
:type drop_last: bool, optional |
||||
""" |
||||
|
||||
def __init__(self, |
||||
dataset: Dataset, |
||||
shuffle: bool = False, |
||||
seed: int = 0, |
||||
drop_last: bool = False) -> None: |
||||
self.dataset = dataset |
||||
self.num_replicas = gpc.get_world_size(ParallelMode.DATA) |
||||
self.rank = gpc.get_local_rank(ParallelMode.DATA) |
||||
self.epoch = 0 |
||||
self.drop_last = drop_last |
||||
# If the dataset length is evenly divisible by # of replicas, then there |
||||
# is no need to drop any data, since the dataset will be split equally. |
||||
# type: ignore[arg-type] |
||||
if self.drop_last and len(self.dataset) % self.num_replicas != 0: |
||||
# Split to nearest available length that is evenly divisible. |
||||
# This is to ensure each rank receives the same amount of data when |
||||
# using this Sampler. |
||||
self.num_samples = math.ceil( |
||||
# `type:ignore` is required because Dataset cannot provide a default __len__ |
||||
# see NOTE in pytorch/torch/utils/data/sampler.py |
||||
(len(self.dataset) - self.num_replicas) / \ |
||||
self.num_replicas # type: ignore[arg-type] |
||||
) |
||||
else: |
||||
self.num_samples = math.ceil( |
||||
len(self.dataset) / self.num_replicas) # type: ignore[arg-type] |
||||
self.total_size = self.num_samples * self.num_replicas |
||||
self.shuffle = shuffle |
||||
self.seed = seed |
||||
|
||||
def __iter__(self) -> Iterator[T_co]: |
||||
if self.shuffle: |
||||
# deterministically shuffle based on epoch and seed |
||||
g = torch.Generator() |
||||
g.manual_seed(self.seed + self.epoch) |
||||
# type: ignore[arg-type] |
||||
indices = torch.randperm(len(self.dataset), generator=g).tolist() |
||||
else: |
||||
indices = list(range(len(self.dataset))) # type: ignore[arg-type] |
||||
|
||||
if not self.drop_last: |
||||
# add extra samples to make it evenly divisible |
||||
padding_size = self.total_size - len(indices) |
||||
if padding_size <= len(indices): |
||||
indices += indices[:padding_size] |
||||
else: |
||||
indices += (indices * math.ceil(padding_size / |
||||
len(indices)))[:padding_size] |
||||
else: |
||||
# remove tail of data to make it evenly divisible. |
||||
indices = indices[:self.total_size] |
||||
assert len(indices) == self.total_size |
||||
|
||||
# subsample |
||||
indices = indices[self.rank:self.total_size:self.num_replicas] |
||||
assert len(indices) == self.num_samples |
||||
|
||||
return iter(indices) |
||||
|
||||
def __len__(self) -> int: |
||||
return self.num_samples |
||||
|
||||
def set_epoch(self, epoch: int) -> None: |
||||
r"""Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas |
||||
use a different random ordering for each epoch. Otherwise, the next iteration of this |
||||
sampler will yield the same ordering. |
||||
|
||||
:param epoch: Epoch number. |
||||
:type epoch: int |
||||
""" |
||||
self.epoch = epoch |
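# Hedged usage sketch: this sampler is presumably constructed for you by the |
# project's get_dataloader helper, but it can also be handed to a plain |
# DataLoader (assumes the DATA parallel group is initialized; sizes are hypothetical): |
# |
#     sampler = DataParallelSampler(train_dataset, shuffle=True, seed=1024) |
#     loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, sampler=sampler) |
#     for epoch in range(num_epochs): |
#         sampler.set_epoch(epoch)   # makes the shuffle differ between epochs |
#         for batch in loader: |
#             ... |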
@ -0,0 +1,9 @@
|
||||
from .parallel_1d import * |
||||
from .parallel_2d import * |
||||
from .parallel_2p5d import * |
||||
from .parallel_3d import * |
||||
from .parallel_sequence import * |
||||
from .parallel_vision_transformer import * |
||||
from .vanilla_resnet import * |
||||
from .vanilla_vision_transformer import * |
||||
from .wrapper import * |
@ -0,0 +1,63 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
import math |
||||
|
||||
import torch |
||||
from torch import Tensor |
||||
from torch import nn |
||||
from colossalai.utils import checkpoint |
||||
|
||||
from colossalai.constants import IS_TENSOR_PARALLEL |
||||
|
||||
|
||||
def divide(numerator, denominator): |
||||
""" only allow exact division """ |
||||
assert numerator % denominator == 0, \ |
||||
'{} is not divisible by {}'.format(numerator, denominator) |
||||
return numerator // denominator |
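# Tiny worked example for the helper above (editorial note): |
# |
#     >>> divide(12, 4) |
#     3 |
# |
# A non-exact division such as divide(10, 3) raises an AssertionError. |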
||||
|
||||
|
||||
def gelu(x: Tensor) -> Tensor: |
||||
"""Implementation of the gelu activation function. |
||||
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): |
||||
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) |
||||
""" |
||||
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) |
||||
|
||||
|
||||
def swish(x: Tensor) -> Tensor: |
||||
return x * torch.sigmoid(x) |
||||
|
||||
|
||||
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} |
||||
|
||||
|
||||
def set_tensor_parallel_attribute(param): |
||||
if not hasattr(param, IS_TENSOR_PARALLEL): |
||||
setattr(param, IS_TENSOR_PARALLEL, True) |
||||
|
||||
|
||||
class CheckpointModule(nn.Module): |
||||
def __init__(self, checkpoint: bool = True): |
||||
super().__init__() |
||||
self.checkpoint = checkpoint |
||||
self._use_checkpoint = checkpoint |
||||
|
||||
def _forward(self, *args): |
||||
raise NotImplementedError( |
||||
'CheckpointModule subclasses should implement the _forward method instead of overriding forward') |
||||
|
||||
def forward(self, *args): |
||||
if self._use_checkpoint: |
||||
return checkpoint(self._forward, *args) |
||||
else: |
||||
return self._forward(*args) |
||||
|
||||
def train(self, mode: bool = True): |
||||
self._use_checkpoint = self.checkpoint |
||||
return super().train(mode=mode) |
||||
|
||||
def eval(self): |
||||
self._use_checkpoint = False |
||||
return super().eval() |
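# Minimal subclassing sketch for the module above (MyBlock and its sizes are |
# hypothetical; it only illustrates overriding _forward so that activation |
# checkpointing can be toggled through the constructor flag): |
# |
#     class MyBlock(CheckpointModule): |
#         def __init__(self, dim, checkpoint=True): |
#             super().__init__(checkpoint) |
#             self.proj = nn.Linear(dim, dim) |
# |
#         def _forward(self, x): |
#             return gelu(self.proj(x)) |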
@ -0,0 +1,138 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
import torch |
||||
import torch.distributed as dist |
||||
|
||||
from colossalai.core import global_context as gpc |
||||
|
||||
|
||||
def _reduce(input_, parallel_mode): |
||||
# skip if only one rank involved |
||||
if gpc.get_world_size(parallel_mode) == 1: |
||||
return input_ |
||||
dist.all_reduce(input_, group=gpc.get_group(parallel_mode)) |
||||
|
||||
return input_ |
||||
|
||||
|
||||
def _split(input_, parallel_mode, dim=-1): |
||||
# skip if only one rank involved |
||||
world_size = gpc.get_world_size(parallel_mode) |
||||
if world_size == 1: |
||||
return input_ |
||||
|
||||
# Split along last dimension. |
||||
dim_size = input_.size(dim) |
||||
assert dim_size % world_size == 0, \ |
||||
f'The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), ' \ |
||||
f'cannot split tensor evenly' |
||||
|
||||
tensor_list = torch.split(input_, dim_size // world_size, dim=dim) |
||||
rank = gpc.get_local_rank(parallel_mode) |
||||
output = tensor_list[rank].contiguous() |
||||
|
||||
return output |
||||
|
||||
|
||||
def _gather(input_, parallel_mode, dim=-1): |
||||
# skip if only one rank involved |
||||
world_size = gpc.get_world_size(parallel_mode) |
||||
if world_size == 1: |
||||
return input_ |
||||
|
||||
# all gather |
||||
rank = gpc.get_local_rank(parallel_mode) |
||||
tensor_list = [torch.empty_like(input_) for _ in range(world_size)] |
||||
tensor_list[rank] = input_ |
||||
torch.distributed.all_gather(tensor_list, input_, group=gpc.get_group(parallel_mode)) |
||||
|
||||
# concat |
||||
output = torch.cat(tensor_list, dim=dim).contiguous() |
||||
|
||||
return output |
||||
|
||||
|
||||
class _ReduceGrad(torch.autograd.Function): |
||||
"""Pass the input to the model parallel region.""" |
||||
|
||||
@staticmethod |
||||
def symbolic(graph, input_): |
||||
return input_ |
||||
|
||||
@staticmethod |
||||
def forward(ctx, input_, parallel_mode): |
||||
ctx.mode = parallel_mode |
||||
return input_ |
||||
|
||||
@staticmethod |
||||
def backward(ctx, grad_output): |
||||
return _reduce(grad_output, ctx.mode), None |
||||
|
||||
|
||||
class _ReduceInput(torch.autograd.Function): |
||||
"""All-reduce the input from the model parallel region.""" |
||||
|
||||
@staticmethod |
||||
def symbolic(graph, input_): |
||||
return _reduce(input_) |
||||
|
||||
@staticmethod |
||||
def forward(ctx, input_, parallel_mode): |
||||
return _reduce(input_, parallel_mode) |
||||
|
||||
@staticmethod |
||||
def backward(ctx, grad_output): |
||||
return grad_output, None |
||||
|
||||
|
||||
class _SplitForwardGatherBackward(torch.autograd.Function): |
||||
"""Split the input and keep only the corresponding chuck to the rank.""" |
||||
|
||||
@staticmethod |
||||
def symbolic(graph, input_): |
||||
return _split(input_) |
||||
|
||||
@staticmethod |
||||
def forward(ctx, input_, parallel_mode, dim): |
||||
ctx.mode = parallel_mode |
||||
ctx.dim = dim |
||||
return _split(input_, parallel_mode, dim) |
||||
|
||||
@staticmethod |
||||
def backward(ctx, grad_output): |
||||
return _gather(grad_output, ctx.mode, ctx.dim), None, None |
||||
|
||||
|
||||
class _GatherForwardSplitBackward(torch.autograd.Function): |
||||
"""Gather the input from model parallel region and concatinate.""" |
||||
|
||||
@staticmethod |
||||
def symbolic(graph, input_): |
||||
return _gather(input_) |
||||
|
||||
@staticmethod |
||||
def forward(ctx, input_, parallel_mode, dim): |
||||
ctx.mode = parallel_mode |
||||
ctx.dim = dim |
||||
return _gather(input_, parallel_mode, dim) |
||||
|
||||
@staticmethod |
||||
def backward(ctx, grad_output): |
||||
return _split(grad_output, ctx.mode, ctx.dim), None, None |
||||
|
||||
|
||||
def reduce_grad(input_, parallel_mode): |
||||
return _ReduceGrad.apply(input_, parallel_mode) |
||||
|
||||
|
||||
def reduce_input(input_, parallel_mode): |
||||
return _ReduceInput.apply(input_, parallel_mode) |
||||
|
||||
|
||||
def split_forward_gather_backward(input_, parallel_mode, dim): |
||||
return _SplitForwardGatherBackward.apply(input_, parallel_mode, dim) |
||||
|
||||
|
||||
def gather_forward_split_backward(input_, parallel_mode, dim): |
||||
return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim) |
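# Editorial sketch of how the two differentiable wrappers above pair up in 1D |
# tensor parallelism (assumes the PARALLEL_1D group exists; shapes are per rank |
# with P = world size of that group): |
# |
#     # forward: gather [B, S, H/P] shards into the full [B, S, H] tensor; |
#     # backward: split the incoming gradient back into [B, S, H/P] shards |
#     full = gather_forward_split_backward(shard, ParallelMode.PARALLEL_1D, dim=-1) |
# |
#     # the mirror image: split in forward, all-gather gradients in backward |
#     shard = split_forward_gather_backward(full, ParallelMode.PARALLEL_1D, dim=-1) |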
@ -0,0 +1,27 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
import torch.nn as nn |
||||
|
||||
from colossalai.context import ParallelMode |
||||
from colossalai.core import global_context as gpc |
||||
|
||||
|
||||
class ParallelLayer(nn.Module): |
||||
|
||||
def __init__(self): |
||||
super().__init__() |
||||
self.data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank( |
||||
ParallelMode.DATA) |
||||
self.data_parallel_size = 1 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_world_size( |
||||
ParallelMode.DATA) |
||||
|
||||
self.tensor_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.TENSOR) else gpc.get_local_rank( |
||||
ParallelMode.TENSOR) |
||||
self.tensor_parallel_size = 1 if not gpc.is_initialized(ParallelMode.TENSOR) else gpc.get_world_size( |
||||
ParallelMode.TENSOR) |
||||
|
||||
self.pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank( |
||||
ParallelMode.PIPELINE) |
||||
self.pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size( |
||||
ParallelMode.PIPELINE) |
@ -0,0 +1,5 @@
|
||||
from .layers import Linear1D_Col, Linear1D_Row |
||||
|
||||
__all__ = [ |
||||
'Linear1D_Col', 'Linear1D_Row', |
||||
] |
@ -0,0 +1,15 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
from .._common_utils import divide |
||||
|
||||
|
||||
def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank): |
||||
index_f = rank * per_partition_vocab_size |
||||
index_l = index_f + per_partition_vocab_size |
||||
return index_f, index_l |
||||
|
||||
|
||||
def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size): |
||||
per_partition_vocab_size = divide(global_vocab_size, world_size) |
||||
return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank) |
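# Worked example (editorial): a 50,000-token vocabulary split over 4 ranks gives |
# rank 1 the index range [12500, 25000): |
# |
#     >>> vocab_range_from_global_vocab_size(50000, rank=1, world_size=4) |
#     (12500, 25000) |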
@ -0,0 +1,166 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
import torch |
||||
import torch.nn as nn |
||||
import torch.nn.functional as F |
||||
import torch.nn.init as init |
||||
from torch import Tensor |
||||
from torch.nn.parameter import Parameter |
||||
from typing import Tuple |
||||
|
||||
from colossalai.context.parallel_mode import ParallelMode |
||||
from colossalai.core import global_context as gpc |
||||
from colossalai.registry import LAYERS |
||||
from colossalai.utils import get_current_device |
||||
from .._common_utils import divide |
||||
from .._parallel_utilities import reduce_grad, reduce_input, gather_forward_split_backward, \ |
||||
split_forward_gather_backward |
||||
from ..base_layer import ParallelLayer |
||||
|
||||
|
||||
class Linear1D_Col(ParallelLayer): |
||||
"""Linear layer with column parallelism. |
||||
|
||||
The linear layer is defined as :math:`Y = XA + b`. A is parallelized along |
||||
its second dimension as :math:`A = [A_1, ..., A_p]`. |
||||
|
||||
:param in_features: first dimension of matrix A. |
||||
:type in_features: int |
||||
:param output_size: second dimension of matrix A. |
||||
:type output_size: int |
||||
:param bias: If true, add bias, defaults to True |
||||
:type bias: bool, optional |
||||
:param dtype: The dtype of parameters, defaults to None |
||||
:type dtype: torch.dtype, optional |
||||
:param gather_output: If true, call all-gather on output and make Y available |
||||
to all GPUs, otherwise, every GPU will have its output |
||||
which is :math:`Y_i = XA_i`, defaults to False |
||||
:type gather_output: bool, optional |
||||
""" |
||||
|
||||
def __init__(self, |
||||
in_features: int, |
||||
output_size: int, |
||||
bias: bool = True, |
||||
dtype: torch.dtype = None, |
||||
gather_output: bool = False): |
||||
super().__init__() |
||||
|
||||
# Keep input parameters |
||||
self.input_size = in_features |
||||
self.output_size = output_size |
||||
self.gather_output = gather_output |
||||
self.skip_bias_add = not bias |
||||
|
||||
world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D) |
||||
self.output_size_per_partition = divide(output_size, world_size) |
||||
|
||||
# Parameters. |
||||
# Initialize weight. |
||||
factory_kwargs = {'device': get_current_device(), 'dtype': dtype} |
||||
self.weight = Parameter(torch.empty( |
||||
self.output_size_per_partition, self.input_size, |
||||
**factory_kwargs)) |
||||
|
||||
if bias: |
||||
self.bias = Parameter(torch.empty( |
||||
self.output_size_per_partition, |
||||
**factory_kwargs)) |
||||
# Always initialize bias to zero. |
||||
with torch.no_grad(): |
||||
self.bias.zero_() |
||||
else: |
||||
self.register_parameter('bias', None) |
||||
|
||||
def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: |
||||
# Set up backprop all-reduce. |
||||
input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) |
||||
# Matrix multiply. |
||||
|
||||
bias = self.bias if not self.skip_bias_add else None |
||||
output_parallel = F.linear(input_parallel, self.weight, bias) |
||||
if self.gather_output: |
||||
# All-gather across the partitions. |
||||
output = gather_forward_split_backward( |
||||
output_parallel, ParallelMode.PARALLEL_1D, dim=-1) |
||||
else: |
||||
output = output_parallel |
||||
if self.skip_bias_add: |
||||
return output, self.bias |
||||
else: |
||||
return output |
||||
|
||||
|
||||
@LAYERS.register_module |
||||
class Linear1D_Row(ParallelLayer): |
||||
""" Linear layer with row parallelism |
||||
|
||||
:param in_features: size of each input sample |
||||
:type in_features: int |
||||
:param out_features: size of each output sample |
||||
:type out_features: int |
||||
:param bias: If set to ``False``, the layer will not learn an additive bias, defaults to True |
||||
:type bias: bool, optional |
||||
:param dtype: The dtype of parameters, defaults to None |
||||
:type dtype: torch.dtype, optional |
||||
:param parallel_input: If set to ``True``, it is assumed that the input is already split across the ranks and no extra split is performed, defaults to False |
||||
:type parallel_input: bool, optional |
||||
""" |
||||
|
||||
def __init__(self, |
||||
in_features: int, |
||||
out_features: int, |
||||
bias: bool = True, |
||||
dtype: torch.dtype = None, |
||||
parallel_input: bool = False |
||||
): |
||||
super().__init__() |
||||
|
||||
# Keep input parameters |
||||
self.in_features = in_features |
||||
self.out_features = out_features |
||||
self.parallel_input = parallel_input |
||||
self.skip_bias_add = not bias |
||||
|
||||
# Divide the weight matrix along the last dimension. |
||||
world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D) |
||||
self.input_size_per_partition = divide(in_features, world_size) |
||||
|
||||
# Parameters. |
||||
# Initialize weight. |
||||
factory_kwargs = {'device': get_current_device(), 'dtype': dtype} |
||||
self.weight = Parameter(torch.empty( |
||||
self.out_features, |
||||
self.input_size_per_partition, |
||||
**factory_kwargs)) |
||||
|
||||
if bias: |
||||
self.bias = Parameter(torch.empty( |
||||
self.out_features, |
||||
**factory_kwargs |
||||
)) |
||||
|
||||
# Always initialize bias to zero. |
||||
with torch.no_grad(): |
||||
self.bias.zero_() |
||||
else: |
||||
self.register_parameter('bias', None) |
||||
|
||||
def reset_parameters(self) -> None: |
||||
init.xavier_normal_(self.weight) |
||||
|
||||
def forward(self, input_: Tensor) -> Tensor: |
||||
# Set up backprop all-reduce. |
||||
if self.parallel_input: |
||||
input_ = input_ |
||||
else: |
||||
input_ = split_forward_gather_backward( |
||||
input_, ParallelMode.PARALLEL_1D, dim=-1) |
||||
|
||||
output_parallel = F.linear(input_, self.weight) |
||||
output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D) |
||||
|
||||
if not self.skip_bias_add: |
||||
output = output + self.bias |
||||
return output |
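# Editorial sketch of the usual pairing of the two layers above (Megatron-style |
# MLP): a column-parallel projection followed by a row-parallel projection, so |
# that a single all-reduce per block suffices. Sizes are hypothetical and the |
# PARALLEL_1D group must be initialized: |
# |
#     fc1 = Linear1D_Col(1024, 4096, gather_output=False)   # each rank computes Y_i = X A_i |
#     fc2 = Linear1D_Row(4096, 1024, parallel_input=True)    # consumes the local shard directly |
#     out = fc2(F.gelu(fc1(x)))                              # all-reduce happens inside fc2 |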
@ -0,0 +1,11 @@
|
||||
from ._operation import Matmul_AB_2D, Matmul_ABT_2D, Matmul_ATB_2D, Add_Bias_2D, matmul_2d |
||||
from ._transformer import TransformerMLP2D, TransformerSelfAttention2D, TransformerLayer2D |
||||
from ._vit import ViTMLP2D, ViTSelfAttention2D, ViTHead2D, ViTPatchEmbedding2D, ViTTokenFuser2D, ViTInputSplitter2D |
||||
from .layers import Linear2D, LayerNorm2D |
||||
|
||||
__all__ = [ |
||||
'Matmul_AB_2D', 'Matmul_ABT_2D', 'Matmul_ATB_2D', 'Add_Bias_2D', 'matmul_2d', |
||||
'TransformerMLP2D', 'TransformerSelfAttention2D', 'TransformerLayer2D', |
||||
'ViTMLP2D', 'ViTSelfAttention2D', 'ViTHead2D', 'ViTPatchEmbedding2D', 'ViTTokenFuser2D', 'ViTInputSplitter2D', |
||||
'Linear2D', 'LayerNorm2D' |
||||
] |
@ -0,0 +1,522 @@
|
||||
from typing import Any, Tuple |
||||
|
||||
import torch |
||||
import torch.distributed as dist |
||||
from torch import Tensor |
||||
|
||||
from colossalai.context.parallel_mode import ParallelMode |
||||
from colossalai.core import global_context as gpc |
||||
from colossalai.utils import get_current_device |
||||
|
||||
|
||||
def matmul_2d(a, |
||||
b, |
||||
summa_dim, |
||||
out_shape, |
||||
row_rank=None, |
||||
col_rank=None, |
||||
row_parallel_mode=ParallelMode.PARALLEL_2D_ROW, |
||||
col_parallel_mode=ParallelMode.PARALLEL_2D_COL, |
||||
): |
||||
"""Matrix multiplication for 2D parallelism |
||||
|
||||
:param a: matrix :math:`A` |
||||
:type a: torch.tensor |
||||
:param b: matrix :math:`B` |
||||
:type b: torch.tensor |
||||
:param summa_dim: dimension of the SUMMA process grid for 2D parallelism |
||||
:type summa_dim: int |
||||
:param out_shape: shape of output tensor |
||||
:type out_shape: tuple |
||||
:param row_rank: the rank of row, defaults to None |
||||
:type row_rank: int, optional |
||||
:param col_rank: the rank of column, defaults to None |
||||
:type col_rank: int, optional |
||||
:param row_parallel_mode: row parallel mode, defaults to ParallelMode.PARALLEL_2D_ROW |
:type row_parallel_mode: :class:`colossalai.context.parallel_mode.ParallelMode`, optional |
:param col_parallel_mode: column parallel mode, defaults to ParallelMode.PARALLEL_2D_COL |
:type col_parallel_mode: :class:`colossalai.context.parallel_mode.ParallelMode`, optional |
||||
:return: :math:`C = AB` |
||||
:rtype: torch.tensor |
||||
""" |
||||
if row_rank is None: |
||||
row_rank = gpc.get_local_rank(col_parallel_mode) |
||||
if col_rank is None: |
||||
col_rank = gpc.get_local_rank(row_parallel_mode) |
||||
|
||||
data_parallel_rank = 0 if not gpc.is_initialized( |
||||
ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA) |
||||
pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank( |
||||
ParallelMode.PIPELINE) |
||||
pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size( |
||||
ParallelMode.PIPELINE) |
||||
tensor_parallel_size = summa_dim ** 2 |
||||
return Matmul_AB_2D(a, b, summa_dim, out_shape, row_rank, col_rank, row_parallel_mode, col_parallel_mode, |
||||
data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size |
||||
) |
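# Editorial note: with a q x q process grid (summa_dim = q), each rank holds an |
# [m/q, k/q] block of A and a [k/q, n/q] block of B; the SUMMA loop in |
# Matmul_AB_2D below broadcasts one block per step along the row and column |
# groups and accumulates the local [m/q, n/q] block of C = AB. Hedged per-rank |
# usage sketch (assumes the 2D groups are initialized and m, n are divisible by q): |
# |
#     C_local = matmul_2d(A_local, B_local, summa_dim=2, out_shape=(m // 2, n // 2)) |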
||||
|
||||
|
||||
class Matmul_AB_2D(torch.autograd.Function): |
||||
"""Matrix multiplication for :math:`C = AB` |
||||
""" |
||||
@staticmethod |
||||
def forward(ctx: Any, |
||||
A: Tensor, |
||||
B: Tensor, |
||||
summa_dim: int, |
||||
out_shape: Tuple[int, ...], |
||||
row_rank: int, |
||||
col_rank: int, |
||||
row_parallel_mode: ParallelMode, |
||||
col_parallel_mode: ParallelMode, |
||||
data_parallel_rank: int, |
||||
pipeline_parallel_rank: int, |
||||
pipeline_parallel_size: int, |
||||
tensor_parallel_size: int) -> Tensor: |
||||
# A: [b / q, s, h / q] -> [(b * s) / q, h / q] |
||||
# B: [h / q, s / q] |
||||
# C: [b / q, s, s / q] -> [(b * s) / q, s / q] |
||||
|
||||
assert A.shape[-1] == B.shape[-2], \ |
||||
'Invalid shapes: A={}, B={} for AB.'.format(A.shape, B.shape) |
||||
|
||||
if ctx: |
||||
ctx.save_for_backward(A, B) |
||||
|
||||
A_shape = A.shape |
||||
A = A.reshape((-1, A_shape[-1])) |
||||
B_shape = B.shape |
||||
B = B.reshape((-1, B_shape[-1])) |
||||
C_shape = (A.shape[0], B.shape[-1]) |
||||
C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device()) |
||||
|
||||
for i in range(summa_dim): |
||||
A_temp = A.clone() |
||||
B_temp = B.clone() |
||||
src_a = i + summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ |
||||
pipeline_parallel_rank * tensor_parallel_size |
||||
dist.broadcast(A_temp, src=src_a, |
||||
group=gpc.get_group(row_parallel_mode)) |
||||
src_b = col_rank + summa_dim * i + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ |
||||
pipeline_parallel_rank * tensor_parallel_size |
||||
dist.broadcast(B_temp, src=src_b, |
||||
group=gpc.get_group(col_parallel_mode)) |
||||
torch.addmm(C, A_temp, B_temp, out=C) |
||||
|
||||
out = C.reshape(out_shape) |
||||
|
||||
if ctx: |
||||
ctx.summa_dim = summa_dim |
||||
ctx.row_rank = row_rank |
||||
ctx.col_rank = col_rank |
||||
ctx.row_parallel_mode = row_parallel_mode |
||||
ctx.col_parallel_mode = col_parallel_mode |
||||
ctx.A_shape = A_shape |
||||
ctx.B_shape = B_shape |
||||
ctx.data_parallel_rank = data_parallel_rank |
||||
ctx.pipeline_parallel_rank = pipeline_parallel_rank |
||||
ctx.pipeline_parallel_size = pipeline_parallel_size |
||||
ctx.tensor_parallel_size = tensor_parallel_size |
||||
return out |
||||
|
||||
@staticmethod |
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: |
||||
A, B = ctx.saved_tensors |
||||
A_grad = Matmul_ABT_2D.forward( |
||||
None, |
||||
output_grad, B, |
||||
ctx.summa_dim, ctx.A_shape, |
||||
ctx.row_rank, ctx.col_rank, |
||||
ctx.row_parallel_mode, |
||||
ctx.col_parallel_mode, |
||||
ctx.data_parallel_rank, |
||||
ctx.pipeline_parallel_rank, |
||||
ctx.pipeline_parallel_size, |
||||
ctx.tensor_parallel_size |
||||
) |
||||
B_grad = Matmul_ATB_2D.forward( |
||||
None, |
||||
A, output_grad, |
||||
ctx.summa_dim, ctx.B_shape, |
||||
ctx.row_rank, ctx.col_rank, |
||||
ctx.row_parallel_mode, |
||||
ctx.col_parallel_mode, |
||||
ctx.data_parallel_rank, |
||||
ctx.pipeline_parallel_rank, |
||||
ctx.pipeline_parallel_size, |
||||
ctx.tensor_parallel_size |
||||
) |
||||
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None |
||||
|
||||
|
||||
class Matmul_ABT_2D(torch.autograd.Function): |
||||
"""Matrix multiplication for :math:`C = AB^T` |
||||
""" |
||||
@staticmethod |
||||
def forward(ctx: Any, |
||||
A: Tensor, |
||||
B: Tensor, |
||||
summa_dim: int, |
||||
out_shape: Tuple[int, ...], |
||||
row_rank: int, |
||||
col_rank: int, |
||||
row_parallel_mode: ParallelMode, |
||||
col_parallel_mode: ParallelMode, |
||||
data_parallel_rank: int, |
||||
pipeline_parallel_rank: int, |
||||
pipeline_parallel_size: int, |
||||
tensor_parallel_size: int |
||||
) -> Tensor: |
||||
|
||||
assert A.shape[-1] == B.shape[-1], \ |
||||
'Invalid shapes: A={}, B={} for ABT.'.format(A.shape, B.shape) |
||||
|
||||
if ctx: |
||||
ctx.save_for_backward(A, B) |
||||
|
||||
A_shape = A.shape |
||||
A = A.reshape((-1, A_shape[-1])) |
||||
B_shape = B.shape |
||||
B = B.reshape((-1, B_shape[-1])) |
||||
C_shape = (A.shape[0], B.shape[0]) |
||||
C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device()) |
||||
|
||||
for i in range(summa_dim): |
||||
B_temp = B.clone() |
||||
# C_temp = torch.zeros(C_shape, dtype=C.dtype, device=get_current_device()) |
||||
src_b = col_rank + summa_dim * i + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ |
||||
pipeline_parallel_rank * tensor_parallel_size |
||||
dist.broadcast(B_temp, src=src_b, |
||||
group=gpc.get_group(col_parallel_mode)) |
||||
C_temp = torch.matmul(A, B_temp.transpose(0, 1)) |
||||
src_c = i + summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ |
||||
pipeline_parallel_rank * tensor_parallel_size |
||||
dist.reduce(C_temp, dst=src_c, |
||||
group=gpc.get_group(row_parallel_mode)) |
||||
if i == col_rank: |
||||
C = C_temp.clone() |
||||
|
||||
out = C.reshape(out_shape) |
||||
|
||||
if ctx: |
||||
ctx.summa_dim = summa_dim |
||||
ctx.row_rank = row_rank |
||||
ctx.col_rank = col_rank |
||||
ctx.row_parallel_mode = row_parallel_mode |
||||
ctx.col_parallel_mode = col_parallel_mode |
||||
ctx.A_shape = A_shape |
||||
ctx.B_shape = B_shape |
||||
ctx.data_parallel_rank = data_parallel_rank |
||||
ctx.pipeline_parallel_rank = pipeline_parallel_rank |
||||
ctx.pipeline_parallel_size = pipeline_parallel_size |
||||
ctx.tensor_parallel_size = tensor_parallel_size |
||||
|
||||
return out |
||||
|
||||
@staticmethod |
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: |
||||
A, B = ctx.saved_tensors |
||||
A_grad = Matmul_AB_2D.forward( |
||||
None, |
||||
output_grad, B, |
||||
ctx.summa_dim, ctx.A_shape, |
||||
ctx.row_rank, ctx.col_rank, |
||||
ctx.row_parallel_mode, |
||||
ctx.col_parallel_mode, |
||||
ctx.data_parallel_rank, |
||||
ctx.pipeline_parallel_rank, |
||||
ctx.pipeline_parallel_size, |
||||
ctx.tensor_parallel_size |
||||
) |
||||
B_grad = Matmul_ATB_2D.forward( |
||||
None, |
||||
output_grad, A, |
||||
ctx.summa_dim, ctx.B_shape, |
||||
ctx.row_rank, ctx.col_rank, |
||||
ctx.row_parallel_mode, |
||||
ctx.col_parallel_mode, |
||||
ctx.data_parallel_rank, |
||||
ctx.pipeline_parallel_rank, |
||||
ctx.pipeline_parallel_size, |
||||
ctx.tensor_parallel_size |
||||
) |
||||
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None |
||||
|
||||
|
||||
class Matmul_ATB_2D(torch.autograd.Function): |
||||
"""Matrix multiplication for :math:`C = A^TB` |
||||
""" |
||||
@staticmethod |
||||
def forward(ctx: Any, |
||||
A: Tensor, |
||||
B: Tensor, |
||||
summa_dim: int, |
||||
out_shape: Tuple[int, ...], |
||||
row_rank: int, |
||||
col_rank: int, |
||||
row_parallel_mode: ParallelMode, |
||||
col_parallel_mode: ParallelMode, |
||||
data_parallel_rank: int, |
||||
pipeline_parallel_rank: int, |
||||
pipeline_parallel_size: int, |
||||
tensor_parallel_size: int |
||||
) -> Tensor: |
||||
|
||||
assert A.shape[-2] == B.shape[-2], \ |
||||
'Invalid shapes: A={}, B={} for ATB.'.format(A.shape, B.shape) |
||||
|
||||
if ctx: |
||||
ctx.save_for_backward(A, B) |
||||
|
||||
A_shape = A.shape |
||||
A = A.reshape((-1, A_shape[-1])) |
||||
B_shape = B.shape |
||||
B = B.reshape((-1, B_shape[-1])) |
||||
C_shape = (A.shape[-1], B.shape[-1]) |
||||
C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device()) |
||||
|
||||
for i in range(summa_dim): |
||||
A_temp = A.clone() |
||||
# C_temp = torch.zeros(C_shape, dtype=C.dtype, device=get_current_device()) |
||||
src_a = i + summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ |
||||
pipeline_parallel_rank * tensor_parallel_size |
||||
dist.broadcast(A_temp, src=src_a, |
||||
group=gpc.get_group(row_parallel_mode)) |
||||
C_temp = torch.matmul(A_temp.transpose(0, 1), B) |
||||
src_c = col_rank + summa_dim * i + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ |
||||
pipeline_parallel_rank * tensor_parallel_size |
||||
dist.reduce(C_temp, dst=src_c, |
||||
group=gpc.get_group(col_parallel_mode)) |
||||
if i == row_rank: |
||||
C = C_temp.clone() |
||||
|
||||
out = C.reshape(out_shape) |
||||
|
||||
if ctx: |
||||
ctx.summa_dim = summa_dim |
||||
ctx.row_rank = row_rank |
||||
ctx.col_rank = col_rank |
||||
ctx.row_parallel_mode = row_parallel_mode |
||||
ctx.col_parallel_mode = col_parallel_mode |
||||
ctx.A_shape = A_shape |
||||
ctx.B_shape = B_shape |
||||
ctx.data_parallel_rank = data_parallel_rank |
||||
ctx.pipeline_parallel_rank = pipeline_parallel_rank |
||||
ctx.pipeline_parallel_size = pipeline_parallel_size |
||||
ctx.tensor_parallel_size = tensor_parallel_size |
||||
|
||||
return out |
||||
|
||||
@staticmethod |
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: |
||||
A, B = ctx.saved_tensors |
||||
A_grad = Matmul_ABT_2D.forward( |
||||
None, |
||||
B, output_grad, |
||||
ctx.summa_dim, ctx.A_shape, |
||||
ctx.row_rank, ctx.col_rank, |
||||
ctx.row_parallel_mode, |
||||
ctx.col_parallel_mode, |
||||
ctx.data_parallel_rank, |
||||
ctx.pipeline_parallel_rank, |
||||
ctx.pipeline_parallel_size, |
||||
ctx.tensor_parallel_size |
||||
) |
||||
B_grad = Matmul_AB_2D.forward( |
||||
None, |
||||
A, output_grad, |
||||
ctx.summa_dim, ctx.B_shape, |
||||
ctx.row_rank, ctx.col_rank, |
||||
ctx.row_parallel_mode, |
||||
ctx.col_parallel_mode, |
||||
ctx.data_parallel_rank, |
||||
ctx.pipeline_parallel_rank, |
||||
ctx.pipeline_parallel_size, |
||||
ctx.tensor_parallel_size |
||||
) |
||||
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None |
||||
|
||||
|
||||
class Add_Bias_2D(torch.autograd.Function): |
||||
"""Matrix add bias: :math:`C = A + b` |
||||
""" |
||||
@staticmethod |
||||
def forward(ctx: Any, |
||||
input: Tensor, |
||||
bias: Tensor, |
||||
output_size_per_partition: int, |
||||
row_rank: int, |
||||
col_rank: int, |
||||
row_parallel_mode: ParallelMode, |
||||
col_parallel_mode: ParallelMode, |
||||
skip_bias_add: bool, |
||||
data_parallel_rank: int, |
||||
pipeline_parallel_rank: int, |
||||
pipeline_parallel_size: int, |
||||
tensor_parallel_size: int |
||||
) -> Tensor: |
||||
if row_rank == 0: |
||||
bias_temp = bias.clone() |
||||
else: |
||||
bias_temp = torch.zeros( |
||||
output_size_per_partition, |
||||
dtype=bias.dtype, |
||||
device=get_current_device()) |
||||
src_rank = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ |
||||
pipeline_parallel_rank * tensor_parallel_size |
||||
dist.broadcast(bias_temp, src=src_rank, |
||||
group=gpc.get_group(col_parallel_mode)) |
||||
|
||||
ctx.row_rank = row_rank |
||||
ctx.col_rank = col_rank |
||||
ctx.row_parallel_mode = row_parallel_mode |
||||
ctx.col_parallel_mode = col_parallel_mode |
||||
ctx.bias = skip_bias_add |
||||
ctx.data_parallel_rank = data_parallel_rank |
||||
ctx.pipeline_parallel_rank = pipeline_parallel_rank |
||||
ctx.pipeline_parallel_size = pipeline_parallel_size |
||||
ctx.tensor_parallel_size = tensor_parallel_size |
||||
|
||||
if skip_bias_add: |
||||
return bias_temp |
||||
else: |
||||
output = input + bias_temp |
||||
return output |
||||
|
||||
@staticmethod |
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: |
||||
row_rank = ctx.row_rank |
||||
col_rank = ctx.col_rank |
||||
row_parallel_mode = ctx.row_parallel_mode |
||||
col_parallel_mode = ctx.col_parallel_mode |
||||
data_parallel_rank = ctx.data_parallel_rank |
||||
pipeline_parallel_rank = ctx.pipeline_parallel_rank |
||||
pipeline_parallel_size = ctx.pipeline_parallel_size |
||||
tensor_parallel_size = ctx.tensor_parallel_size |
||||
|
||||
if ctx.bias: |
||||
dst_rank = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ |
||||
pipeline_parallel_rank * tensor_parallel_size |
||||
dist.reduce(output_grad, dst=dst_rank, |
||||
group=gpc.get_group(col_parallel_mode)) |
||||
if row_rank == 0: |
||||
return None, output_grad, None, None, None, None, None, None, None, None, None, None |
||||
else: |
||||
# for compatibility with zero optimizer, no grad should be None |
||||
grad_tmp = torch.zeros_like(output_grad) |
||||
return None, grad_tmp, None, None, None, None, None, None, None, None, None, None |
||||
else: |
||||
reduce_dim = tuple(range(output_grad.ndim - 1)) |
||||
reduce = torch.sum(output_grad, dim=reduce_dim) |
||||
dst_rank = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ |
||||
pipeline_parallel_rank * tensor_parallel_size |
||||
dist.reduce(reduce, dst=dst_rank, |
||||
group=gpc.get_group(col_parallel_mode)) |
||||
if row_rank == 0: |
||||
return output_grad, reduce, None, None, None, None, None, None, None, None, None, None |
||||
else: |
||||
# for compatibility with zero optimizer, no grad should be None |
||||
reduce_tmp = torch.zeros_like(reduce) |
||||
return output_grad, reduce_tmp, None, None, None, None, None, None, None, None, None, None |
||||
|
||||
|
||||
class _LayerNorm_2D(torch.autograd.Function): |
||||
|
||||
@staticmethod |
||||
def forward(ctx: Any, |
||||
input: Tensor, |
||||
E_x: Tensor, |
||||
Var_x: Tensor, |
||||
hidden_size: int, |
||||
row_parallel_mode: ParallelMode, |
||||
col_parallel_mode: ParallelMode) -> Tensor: |
||||
input = input - E_x |
||||
# in here, input = x - E[x], Var_x = 1 / sqrt(Var[x] + eps) |
||||
ctx.normalized_shape = hidden_size |
||||
output = input * Var_x |
||||
ctx.save_for_backward(output, Var_x) |
||||
ctx.row_parallel_mode = row_parallel_mode |
||||
ctx.col_parallel_mode = col_parallel_mode |
||||
return output |
||||
|
||||
@staticmethod |
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: |
||||
row_parallel_mode = ctx.row_parallel_mode |
||||
col_parallel_mode = ctx.col_parallel_mode |
||||
x, Var_x = ctx.saved_tensors |
||||
# in here, Var_x = 1 / sqrt(Var[x] + eps), x = (x - E[x]) * Var_x |
||||
output_grad_sum = torch.sum(output_grad, dim=-1, keepdim=True) |
||||
torch.distributed.all_reduce( |
||||
output_grad_sum, group=gpc.get_group(row_parallel_mode)) |
||||
output_grad_sum /= ctx.normalized_shape |
||||
|
||||
output_grad_mul_x_sum = torch.sum( |
||||
output_grad * x, dim=-1, keepdim=True) |
||||
torch.distributed.all_reduce( |
||||
output_grad_mul_x_sum, group=gpc.get_group(row_parallel_mode)) |
||||
output_grad_mul_x_sum /= ctx.normalized_shape |
||||
|
||||
input_grad = output_grad.clone() |
||||
input_grad -= x * output_grad_mul_x_sum |
||||
input_grad -= output_grad_sum |
||||
input_grad *= Var_x |
||||
|
||||
return input_grad, None, None, None, None, None |
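# Editorial note on the backward pass above: writing x_hat = (x - E[x]) / sqrt(Var[x] + eps) |
# and g = dL/dy, the layer-norm gradient (without affine parameters) is |
# |
#     dL/dx = (g - mean(g) - x_hat * mean(g * x_hat)) / sqrt(Var[x] + eps), |
# |
# where both means run over the full hidden dimension; since that dimension is |
# sharded across the row-parallel group, the partial sums are all-reduced before |
# being divided by the full hidden size. |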
||||
|
||||
|
||||
# class Sum_2D(torch.autograd.Function): |
||||
# |
||||
# @staticmethod |
||||
# def forward(ctx: Any, |
||||
# inputs: Tensor, |
||||
# dim: int, |
||||
# summa_dim: int, |
||||
# row_parallel_mode: ParallelMode, |
||||
# keepdim: bool = False) -> Tensor: |
||||
# # input: [b/q, s, h/q] |
||||
# empty_cache() |
||||
# ctx.save_for_backward(inputs) |
||||
# # sum: [b/q, s] |
||||
# out = torch.sum(inputs, dim=dim, keepdim=keepdim) |
||||
# torch.distributed.all_reduce(out, group=gpc.get_group(row_parallel_mode)) |
||||
# return out |
||||
# |
||||
# @staticmethod |
||||
# def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: |
||||
# with torch.no_grad(): |
||||
# inputs = ctx.saved_tensors |
||||
# input_grad = torch.ones(inputs.shape, dtype=output_grad.dtype) |
||||
# return input_grad, None, None, None, None, None |
||||
|
||||
|
||||
class _ViT_Split_Input_2D(torch.autograd.Function): |
||||
|
||||
@staticmethod |
||||
def forward(ctx: Any, |
||||
inputs: Tensor, |
||||
batch_size: int, |
||||
summa_dim: int, |
||||
col_parallel_mode: ParallelMode) -> Tensor: |
||||
# inputs: [b, s, h/q] |
||||
# output: [b/q, s, h/q] |
||||
|
||||
ctx.BATCH_SIZE = batch_size |
||||
ctx.summa_dim = summa_dim |
||||
ctx.col_parallel_mode = col_parallel_mode |
||||
row_rank = gpc.get_local_rank(col_parallel_mode) |
||||
output = torch.chunk(inputs, summa_dim, dim=0)[row_rank] |
||||
output = output.clone() |
||||
return output |
||||
|
||||
@staticmethod |
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: |
||||
# output_grad: [b/q, s, h/q] |
||||
# grads: [b, s, h/q] |
||||
grads_shape = (ctx.BATCH_SIZE,) + output_grad.shape[1:] |
||||
grads = torch.empty(grads_shape, |
||||
dtype=output_grad.dtype, |
||||
device=get_current_device()) |
||||
dist.all_gather(list(grads.chunk(ctx.summa_dim, dim=0)), |
||||
output_grad.contiguous(), |
||||
group=gpc.get_group(ctx.col_parallel_mode)) |
||||
return grads, None, None, None |
@ -0,0 +1,220 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
import math |
||||
|
||||
import torch |
||||
from torch import nn as nn, Tensor |
||||
|
||||
from colossalai.nn.layer._common_utils import divide, ACT2FN |
||||
from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization, get_summa_dim_from_env |
||||
from colossalai.registry import LAYERS |
||||
from .layers import Linear2D, LayerNorm2D |
||||
from ..base_layer import ParallelLayer |
||||
|
||||
|
||||
@LAYERS.register_module |
||||
class TransformerMLP2D(ParallelLayer): |
||||
""" |
||||
The MLP takes an input of hidden size h, projects it to a hidden dimension of |
mlp_ratio * h, applies a nonlinear transformation, and projects the state back |
to hidden size h. Dropout and a residual layer norm are applied at the end. |
||||
|
||||
:param in_features: the size of input tensor |
||||
:type in_features: int |
||||
:param mlp_ratio: hidden size of MLP divided by embedding dim, defaults to 4.0 |
||||
:type mlp_ratio: float, optional |
||||
:param act_func: activation function, defaults to 'gelu' |
||||
:type act_func: str, optional |
||||
:param dropout_prob: dropout probability, defaults to 0. |
||||
:type dropout_prob: float, optional |
||||
:param dtype: dtype of parameters, defaults to None |
||||
:type dtype: torch.dtype, optional |
||||
:param skip_bias_add: If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion, defaults to False |
||||
:type skip_bias_add: bool, optional |
||||
""" |
||||
|
||||
def __init__(self, |
||||
in_features: int, |
||||
mlp_ratio: float = 4.0, |
||||
act_func: str = 'gelu', |
||||
dropout_prob: float = 0., |
||||
dtype=None, |
||||
skip_bias_add: bool = False |
||||
): |
||||
super().__init__() |
||||
assert_summa_initialization() |
||||
self.summa_dim = get_summa_dim_from_env() |
||||
self.in_features = in_features |
||||
self.skip_bias_add = skip_bias_add |
||||
|
||||
# Project to h * mlp_ratio. |
||||
self.dense_1 = Linear2D( |
||||
in_features, |
||||
int(mlp_ratio * in_features), |
||||
dtype=dtype, |
||||
skip_bias_add=self.skip_bias_add |
||||
) |
||||
|
||||
assert act_func in ACT2FN.keys(), f'Invalid value for argument act_func, ' \ |
||||
f'activation function can only be {list(ACT2FN.keys())}' |
||||
self.activation_func = ACT2FN[act_func] |
||||
|
||||
# Project back to h. |
||||
self.dense_2 = Linear2D( |
||||
int(mlp_ratio * in_features), |
||||
in_features, |
||||
dtype=dtype, |
||||
skip_bias_add=self.skip_bias_add |
||||
) |
||||
self.dropout = nn.Dropout(dropout_prob) |
||||
self.layernorm = LayerNorm2D(in_features, dtype=dtype) |
||||
|
||||
def forward(self, x: Tensor) -> Tensor: |
||||
if self.skip_bias_add: |
||||
intermediate_output, _ = self.dense_1(x) |
||||
else: |
||||
intermediate_output = self.dense_1(x) |
||||
|
||||
intermediate_output = self.activation_func(intermediate_output) |
||||
|
||||
if self.skip_bias_add: |
||||
output, _ = self.dense_2(intermediate_output) |
||||
else: |
||||
output = self.dense_2(intermediate_output) |
||||
|
||||
output = self.dropout(output) |
||||
output = self.layernorm(x + output) |
||||
return output |
||||
|
||||
|
||||
@LAYERS.register_module |
||||
class TransformerSelfAttention2D(ParallelLayer): |
||||
"""Self attention layer for 2D parallel Transformer |
||||
|
||||
:param hidden_size: hidden size |
||||
:type hidden_size: int |
||||
:param num_attention_heads: number of attention heads |
||||
:type num_attention_heads: int |
||||
:param attention_dropout_prob: dropout probability for attention layer |
||||
:type attention_dropout_prob: float |
||||
:param hidden_dropout_prob: dropout probability for hidden layer |
||||
:type hidden_dropout_prob: float |
||||
:param dtype: dtype of parameters, defaults to None |
||||
:type dtype: torch.dtype, optional |
||||
""" |
||||
|
||||
def __init__(self, |
||||
hidden_size: int, |
||||
num_attention_heads: int, |
||||
attention_dropout_prob: float, |
||||
hidden_dropout_prob: float, |
||||
dtype=None, |
||||
): |
||||
|
||||
super().__init__() |
||||
|
||||
assert_summa_initialization() |
||||
self.summa_dim = get_summa_dim_from_env() |
||||
self.hidden_size = hidden_size |
||||
self.num_attention_heads = divide(num_attention_heads, self.summa_dim) |
||||
self.attention_head_size = divide(hidden_size, num_attention_heads) |
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size |
||||
|
||||
self.query_key_value = Linear2D( |
||||
hidden_size, |
||||
3 * hidden_size, |
||||
dtype=dtype, |
||||
) |
||||
self.attention_dropout = nn.Dropout(attention_dropout_prob) |
||||
self.dense = Linear2D( |
||||
hidden_size, |
||||
hidden_size, |
||||
dtype=dtype, |
||||
) |
||||
self.dropout = nn.Dropout(hidden_dropout_prob) |
||||
self.layernorm = LayerNorm2D( |
||||
hidden_size, |
||||
dtype=dtype) |
||||
|
||||
def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor: |
||||
query_key_value = self.query_key_value(hidden_states) |
||||
new_qkv_shape = query_key_value.shape[:-1] + \ |
||||
(self.num_attention_heads, 3 * self.attention_head_size) |
||||
query_key_value = query_key_value.view(new_qkv_shape) |
||||
query_key_value = query_key_value.permute((0, 2, 1, 3)) |
||||
query_layer, key_layer, value_layer = torch.chunk( |
||||
query_key_value, 3, dim=-1) |
||||
|
||||
attention_scores = torch.matmul( |
||||
query_layer, key_layer.transpose(-1, -2)) |
||||
attention_scores = attention_scores / \ |
||||
math.sqrt(self.attention_head_size) |
||||
attention_scores = attention_scores + attention_mask |
||||
attention_probs = nn.Softmax(dim=-1)(attention_scores) |
||||
attention_probs = self.attention_dropout(attention_probs) |
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer) |
||||
context_layer = context_layer.permute((0, 2, 1, 3)).contiguous() |
||||
new_context_layer_shape = context_layer.size()[ |
||||
:-2] + (self.all_head_size,) |
||||
context_layer = context_layer.view(*new_context_layer_shape) |
||||
|
||||
output = self.dense(context_layer) |
||||
output = self.dropout(output) |
||||
attention_output = self.layernorm(hidden_states + output) |
||||
|
||||
return attention_output |
||||
|
||||
|
||||
@LAYERS.register_module |
||||
class TransformerLayer2D(ParallelLayer): |
||||
"""Transformer layer which contains a self-attention layer and a MLP layer |
||||
|
||||
:param hidden_size: hidden size |
||||
:type hidden_size: int |
||||
:param num_attention_heads: number of attention heads |
||||
:type num_attention_heads: int |
||||
:param act_func: activation function, defaults to 'gelu' |
||||
:type act_func: str, optional |
||||
:param mlp_ratio: hidden size of MLP divided by embedding dim, defaults to 4.0 |
||||
:type mlp_ratio: float, optional |
||||
:param attention_dropout_prob: dropout probability for attention layer, defaults to 0. |
||||
:type attention_dropout_prob: float, optional |
||||
:param hidden_dropout_prob: dropout probability for hidden layers, defaults to 0. |
||||
:type hidden_dropout_prob: float, optional |
||||
:param dtype: dtype of parameters, defaults to None |
||||
:type dtype: torch.dtype, optional |
||||
""" |
||||
|
||||
def __init__(self, |
||||
hidden_size: int, |
||||
num_attention_heads: int, |
||||
act_func: str = 'gelu', |
||||
mlp_ratio: float = 4.0, |
||||
attention_dropout_prob: float = 0., |
||||
hidden_dropout_prob: float = 0., |
||||
dtype=None, |
||||
): |
||||
super().__init__() |
||||
|
||||
self.attention = TransformerSelfAttention2D( |
||||
hidden_size=hidden_size, |
||||
num_attention_heads=num_attention_heads, |
||||
attention_dropout_prob=attention_dropout_prob, |
||||
hidden_dropout_prob=hidden_dropout_prob, |
||||
dtype=dtype, |
||||
) |
||||
self.mlp = TransformerMLP2D( |
||||
in_features=hidden_size, |
||||
dropout_prob=hidden_dropout_prob, |
||||
act_func=act_func, |
||||
mlp_ratio=mlp_ratio, |
||||
dtype=dtype, |
||||
) |
||||
|
||||
def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor: |
||||
attention_output = self.attention(hidden_states, attention_mask) |
||||
output = self.mlp(attention_output) |
||||
return output |
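# A minimal usage sketch for TransformerLayer2D, assuming colossalai has |
# already initialized the 2D (SUMMA) process groups; the names and shapes |
# below are hypothetical and kept as comments only: |
# |
#   layer = TransformerLayer2D(hidden_size=768, num_attention_heads=12) |
#   # hidden_states is the locally partitioned activation tensor and |
#   # attention_mask an additive mask broadcastable to the attention scores |
#   out = layer(hidden_states, attention_mask) |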
@ -0,0 +1,23 @@
|
||||
import os |
||||
|
||||
from colossalai.context.parallel_mode import ParallelMode |
||||
from colossalai.context.process_group_initializer.initializer_2d import SUMMA_DIM |
||||
from colossalai.core import global_context as gpc |
||||
|
||||
|
||||
def get_summa_dim_from_env() -> int: |
||||
try: |
||||
summa_dim = os.environ[SUMMA_DIM] |
||||
summa_dim = int(summa_dim) |
||||
assert summa_dim > 0, 'SUMMA_DIM must be larger than zero' |
||||
return summa_dim |
||||
|
||||
except KeyError as e: |
||||
raise EnvironmentError('SUMMA_DIM is not found in the current environment, ' |
||||
'please make sure that you have used the correct process group initializer') from e |
||||
|
||||
|
||||
def assert_summa_initialization(): |
||||
assert gpc.is_initialized(ParallelMode.PARALLEL_2D_COL) and \ |
||||
gpc.is_initialized(ParallelMode.PARALLEL_2D_ROW), \ |
||||
'Both PARALLEL_2D_COL and PARALLEL_2D_ROW must be initialized by the process group initializer' |
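# A short note on the environment contract assumed here: the 2D process group |
# initializer writes the side length q of the square SUMMA grid into the |
# environment under the key named by SUMMA_DIM. For example (hypothetical), |
# with tensor_parallel_size = 4 it builds a 2 x 2 grid, so |
# get_summa_dim_from_env() returns 2 on every rank. |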
@ -0,0 +1,391 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
import math |
||||
|
||||
import torch |
||||
from torch import nn as nn, Tensor, distributed as dist |
||||
|
||||
from colossalai.context import seed, ParallelMode |
||||
from colossalai.core import global_context as gpc |
||||
from colossalai.nn.layer._common_utils import divide, ACT2FN |
||||
from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization, get_summa_dim_from_env |
||||
from colossalai.nn.layer.vanilla_vision_transformer.layers import to_2tuple |
||||
from colossalai.registry import LAYERS |
||||
from colossalai.utils import checkpoint |
||||
from colossalai.utils import get_current_device |
||||
from ._operation import _ViT_Split_Input_2D |
||||
from .layers import Linear2D |
||||
from .._common_utils import set_tensor_parallel_attribute |
||||
from ..base_layer import ParallelLayer |
||||
|
||||
|
||||
@LAYERS.register_module |
||||
class ViTMLP2D(ParallelLayer): |
||||
"""MLP layer for 2D parallel Vision Transformer |
||||
|
||||
:param in_features: size of each input sample |
||||
:type in_features: int |
||||
:param mlp_ratio: hidden size of MLP divided by embedding dim |
||||
:type mlp_ratio: int |
||||
:param act_func: activation function, defaults to 'gelu' |
||||
:type act_func: str, optional |
||||
:param dropout_prob: dropout probability, defaults to 0. |
||||
:type dropout_prob: float, optional |
||||
:param dtype: The dtype of parameters, defaults to None |
||||
:type dtype: torch.dtype, optional |
||||
:param checkpoint: whether to checkpoint the layer, defaults to False |
||||
:type checkpoint: bool, optional |
||||
""" |
||||
|
||||
def __init__(self, |
||||
in_features: int, |
||||
mlp_ratio: int, |
||||
act_func: str = 'gelu', |
||||
dropout_prob: float = 0., |
||||
dtype=None, |
||||
checkpoint: bool = False |
||||
): |
||||
super().__init__() |
||||
|
||||
assert_summa_initialization() |
||||
self.summa_dim = get_summa_dim_from_env() |
||||
self.in_features = in_features |
||||
self.mlp_ratio = mlp_ratio |
||||
self.checkpoint = checkpoint |
||||
|
||||
# Project to mlp_ratio * h. |
||||
self.dense_1 = Linear2D( |
||||
self.in_features, |
||||
self.mlp_ratio * self.in_features, |
||||
dtype=dtype, |
||||
) |
||||
|
||||
self.act = ACT2FN[act_func] |
||||
|
||||
# Project back to h. |
||||
self.dense_2 = Linear2D( |
||||
self.mlp_ratio * self.in_features, |
||||
self.in_features, |
||||
dtype=dtype, |
||||
) |
||||
self.dropout = nn.Dropout(dropout_prob) |
||||
|
||||
def _forward(self, hidden_states: Tensor) -> Tensor: |
||||
intermediate_output = self.dense_1(hidden_states) |
||||
intermediate_output = self.act(intermediate_output) |
||||
|
||||
with seed(ParallelMode.TENSOR): |
||||
intermediate_output = self.dropout(intermediate_output) |
||||
output = self.dense_2(intermediate_output) |
||||
|
||||
with seed(ParallelMode.TENSOR): |
||||
output = self.dropout(output) |
||||
return output |
||||
|
||||
def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor: |
||||
return checkpoint(self._forward, hidden_states) |
||||
|
||||
def forward(self, hidden_states: Tensor) -> Tensor: |
||||
if self.checkpoint: |
||||
return self._checkpoint_forward(hidden_states) |
||||
else: |
||||
return self._forward(hidden_states) |
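# A brief note on the checkpoint flag above (a sketch, not a guarantee of the |
# exact implementation): with checkpoint=True the MLP does not keep its |
# intermediate activations; colossalai.utils.checkpoint is expected to rerun |
# _forward during the backward pass, trading one extra forward computation |
# per layer for a large reduction in activation memory, e.g. |
#   mlp = ViTMLP2D(in_features=768, mlp_ratio=4, checkpoint=True)  # hypothetical values |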
||||
|
||||
|
||||
@LAYERS.register_module |
||||
class ViTSelfAttention2D(ParallelLayer): |
||||
"""Self-attention layer for 2D parallel Vision Transformer |
||||
|
||||
:param hidden_size: hidden size |
||||
:type hidden_size: int |
||||
:param num_attention_heads: number of attention heads |
||||
:type num_attention_heads: int |
||||
:param attention_dropout_prob: dropout probability for attention layers |
||||
:type attention_dropout_prob: float |
||||
:param hidden_dropout_prob: dropout probability for hidden layers |
||||
:type hidden_dropout_prob: float |
||||
:param dtype: dtype of parameters, defaults to None |
||||
:type dtype: torch.dtype, optional |
||||
:param checkpoint: whether to checkpoint the layer, defaults to False |
||||
:type checkpoint: bool, optional |
||||
""" |
||||
|
||||
def __init__(self, |
||||
hidden_size: int, |
||||
num_attention_heads: int, |
||||
attention_dropout_prob: float, |
||||
hidden_dropout_prob: float, |
||||
dtype=None, |
||||
checkpoint: bool = False |
||||
): |
||||
super().__init__() |
||||
|
||||
assert_summa_initialization() |
||||
self.summa_dim = get_summa_dim_from_env() |
||||
self.hidden_size = hidden_size |
||||
self.num_attention_heads = divide(num_attention_heads, self.summa_dim) |
||||
self.attention_head_size = divide(hidden_size, num_attention_heads) |
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size |
||||
self.checkpoint = checkpoint |
||||
|
||||
self.query_key_value = Linear2D( |
||||
hidden_size, |
||||
3 * hidden_size, |
||||
dtype=dtype, |
||||
) |
||||
self.attention_dropout = nn.Dropout(attention_dropout_prob) |
||||
self.dense = Linear2D( |
||||
hidden_size, |
||||
hidden_size, |
||||
dtype=dtype, |
||||
) |
||||
self.dropout = nn.Dropout(hidden_dropout_prob) |
||||
self.softmax = nn.Softmax(dim=-1) |
||||
|
||||
def _forward(self, hidden_states: Tensor) -> Tensor: |
||||
query_key_value = self.query_key_value(hidden_states) |
||||
new_qkv_shape = query_key_value.shape[:-1] + \ |
||||
(self.num_attention_heads, 3 * self.attention_head_size) |
||||
query_key_value = query_key_value.view(new_qkv_shape) |
||||
query_key_value = query_key_value.permute((0, 2, 1, 3)) |
||||
query_layer, key_layer, value_layer = torch.chunk( |
||||
query_key_value, 3, dim=-1) |
||||
|
||||
attention_scores = torch.matmul( |
||||
query_layer, key_layer.transpose(-1, -2)) |
||||
attention_scores = attention_scores / \ |
||||
math.sqrt(self.attention_head_size) |
||||
|
||||
attention_probs = self.softmax(attention_scores) |
||||
|
||||
with seed(ParallelMode.TENSOR): |
||||
attention_probs = self.attention_dropout(attention_probs) |
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer) |
||||
context_layer = context_layer.transpose(1, 2) |
||||
new_context_layer_shape = context_layer.size()[ |
||||
:-2] + (self.all_head_size,) |
||||
context_layer = context_layer.reshape(new_context_layer_shape) |
||||
|
||||
output = self.dense(context_layer) |
||||
with seed(ParallelMode.TENSOR): |
||||
output = self.dropout(output) |
||||
return output |
||||
|
||||
def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor: |
||||
return checkpoint(self._forward, hidden_states) |
||||
|
||||
def forward(self, hidden_states: Tensor) -> Tensor: |
||||
if self.checkpoint: |
||||
return self._checkpoint_forward(hidden_states) |
||||
else: |
||||
return self._forward(hidden_states) |
||||
|
||||
|
||||
@LAYERS.register_module |
||||
class ViTHead2D(ParallelLayer): |
||||
"""Output layer for 2D parallel Vision Transformer |
||||
|
||||
:param hidden_size: hidden size |
||||
:type hidden_size: int |
||||
:param num_classes: number of classes |
||||
:type num_classes: int |
||||
:param dtype: dtype of parameters, defaults to None |
||||
:type dtype: torch.dtype, optional |
||||
""" |
||||
|
||||
def __init__(self, |
||||
hidden_size, |
||||
num_classes, |
||||
dtype=None, |
||||
): |
||||
super().__init__() |
||||
assert_summa_initialization() |
||||
self.summa_dim = get_summa_dim_from_env() |
||||
self.linear = Linear2D( |
||||
hidden_size, |
||||
num_classes, |
||||
dtype=dtype, |
||||
) |
||||
|
||||
def forward(self, x: Tensor) -> Tensor: |
||||
x = x[:, 0] |
||||
x = self.linear(x) |
||||
return x |
||||
|
||||
|
||||
@LAYERS.register_module |
||||
class ViTPatchEmbedding2D(ParallelLayer): |
||||
""" 2D Image to Patch Embedding |
||||
|
||||
:param img_size: image size |
||||
:type img_size: int |
||||
:param patch_size: patch size |
||||
:type patch_size: int |
||||
:param embed_dim: dimension of embedding |
||||
:type embed_dim: int |
||||
:param in_chans: number of channels of input image, defaults to 3 |
||||
:type in_chans: int, optional |
||||
:param flatten: whether to flatten output tensor, defaults to True |
||||
:type flatten: bool, optional |
||||
""" |
||||
|
||||
def __init__(self, |
||||
img_size, |
||||
patch_size, |
||||
embed_dim, |
||||
in_chans=3, |
||||
flatten=True): |
||||
super().__init__() |
||||
img_size = to_2tuple(img_size) |
||||
patch_size = to_2tuple(patch_size) |
||||
|
||||
assert_summa_initialization() |
||||
self.summa_dim = get_summa_dim_from_env() |
||||
self.img_size = img_size |
||||
self.patch_size = patch_size |
||||
self.grid_size = (img_size[0] // patch_size[0], |
||||
img_size[1] // patch_size[1]) |
||||
self.num_patches = self.grid_size[0] * self.grid_size[1] |
||||
self.flatten = flatten |
||||
self.embed_dim = embed_dim // self.summa_dim |
||||
|
||||
with seed(ParallelMode.TENSOR): |
||||
# ensure the partitions are initialized differently |
||||
self.proj = nn.Conv2d(in_chans, |
||||
self.embed_dim, |
||||
kernel_size=patch_size, |
||||
stride=patch_size |
||||
) |
||||
|
||||
# sync |
||||
self._broadcast_conv_params() |
||||
self.proj.weight.register_hook(self._sync_grad_during_backward) |
||||
self.proj.bias.register_hook(self._sync_grad_during_backward) |
self._set_tensor_parallel_attribute() |
||||
|
||||
def _set_tensor_parallel_attribute(self): |
||||
set_tensor_parallel_attribute(self.proj.weight) |
||||
set_tensor_parallel_attribute(self.proj.bias) |
||||
|
||||
def _broadcast_conv_params(self) -> None: |
||||
self.to(get_current_device()) |
||||
ranks_in_col = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2D_COL) |
||||
|
||||
dist.broadcast(self.proj.weight, src=ranks_in_col[0], |
||||
group=gpc.get_group(ParallelMode.PARALLEL_2D_COL)) |
||||
dist.broadcast(self.proj.bias, src=ranks_in_col[0], |
||||
group=gpc.get_group(ParallelMode.PARALLEL_2D_COL)) |
||||
|
||||
def _sync_grad_during_backward(self, grad: Tensor) -> None: |
||||
dist.all_reduce(grad, group=gpc.get_group( |
||||
ParallelMode.PARALLEL_2D_COL)) |
||||
grad = grad / self.summa_dim |
||||
return grad |
||||
|
||||
def forward(self, x: Tensor) -> Tensor: |
||||
B, C, H, W = x.shape |
||||
assert H == self.img_size[0] and W == self.img_size[1], \ |
||||
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." |
||||
x = self.proj(x) |
||||
if self.flatten: |
||||
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC |
||||
return x |
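# A worked example for the patch embedding above (hypothetical values): with |
# img_size=224, patch_size=16, embed_dim=768 and summa_dim=2, the Conv2d |
# produces embed_dim // summa_dim = 384 channels, grid_size = (14, 14) and |
# num_patches = 196, so the flattened output has shape [B, 196, 384] per rank. |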
||||
|
||||
|
||||
@LAYERS.register_module |
||||
class ViTTokenFuser2D(ParallelLayer): |
||||
""" |
||||
Fuse the cls token and position embedding into the input sequence |
||||
|
||||
:param img_size: image size |
||||
:type img_size: int |
||||
:param patch_size: patch size |
||||
:type patch_size: int |
||||
:param embed_dim: dimension of embedding |
||||
:type embed_dim: int |
||||
:param drop_rate: dropout probability, defaults to 0. |
||||
:type drop_rate: float, optional |
||||
""" |
||||
|
||||
def __init__(self, |
||||
img_size, |
||||
patch_size, |
||||
embed_dim, |
||||
drop_rate=0. |
||||
): |
||||
super().__init__() |
||||
img_size = to_2tuple(img_size) |
||||
patch_size = to_2tuple(patch_size) |
||||
|
||||
assert_summa_initialization() |
||||
self.summa_dim = get_summa_dim_from_env() |
||||
self.img_size = img_size |
||||
self.patch_size = patch_size |
||||
self.grid_size = (img_size[0] // patch_size[0], |
||||
img_size[1] // patch_size[1]) |
||||
self.num_patches = self.grid_size[0] * self.grid_size[1] |
||||
self.embed_dim = embed_dim |
||||
|
||||
self.cls_token = nn.Parameter(torch.zeros( |
||||
1, 1, self.embed_dim // self.summa_dim)) |
||||
self.pos_embed = nn.Parameter(torch.zeros( |
||||
1, self.num_patches + 1, self.embed_dim // self.summa_dim)) |
||||
|
||||
# move to cuda before broadcast |
||||
self.to(get_current_device()) |
||||
|
||||
# sync param in both forward and backward |
||||
_cls_token = self.cls_token.view(-1) |
||||
_pos_embed = self.pos_embed.view(-1) |
||||
self._param = torch.cat([_cls_token, _pos_embed], dim=0) |
||||
|
||||
self._broadcast_params(self._param) |
||||
self._param.register_hook(self._sync_grad_hook) |
||||
self.pos_drop = nn.Dropout(p=drop_rate) |
||||
self._set_tensor_parallel_attribute() |
||||
|
||||
def _set_tensor_parallel_attribute(self): |
||||
set_tensor_parallel_attribute(self.cls_token) |
||||
set_tensor_parallel_attribute(self.pos_embed) |
||||
|
||||
def _broadcast_params(self, param) -> None: |
"""Broadcast to all column ranks for data consistency.""" |
||||
ranks_in_col = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2D_COL) |
||||
col_group = gpc.get_group(ParallelMode.PARALLEL_2D_COL) |
||||
dist.broadcast(param, src=ranks_in_col[0], |
||||
group=col_group) |
||||
|
||||
def _sync_grad_hook(self, grad) -> None: |
||||
dist.all_reduce(grad, group=gpc.get_group( |
||||
ParallelMode.PARALLEL_2D_COL)) |
||||
grad = grad / self.summa_dim |
||||
return grad |
||||
|
||||
def forward(self, x: Tensor) -> Tensor: |
||||
# stole cls_tokens impl from Phil Wang, thanks |
||||
cls_token = self.cls_token.expand(x.shape[0], -1, -1) |
||||
x = torch.cat((cls_token, x), dim=1) |
||||
with seed(ParallelMode.TENSOR): |
||||
x = self.pos_drop(x + self.pos_embed) |
||||
return x |
||||
|
||||
|
||||
@LAYERS.register_module |
||||
class ViTInputSplitter2D(ParallelLayer): |
||||
"""Split the input tensor for 2D parallel Vision Transformer |
||||
""" |
||||
|
||||
def __init__(self): |
||||
super().__init__() |
||||
assert_summa_initialization() |
||||
self.summa_dim = get_summa_dim_from_env() |
||||
|
||||
def forward(self, x: Tensor) -> Tensor: |
||||
batch_size = x.size(0) |
||||
return _ViT_Split_Input_2D.apply( |
||||
x, |
||||
batch_size, |
||||
self.summa_dim, |
||||
ParallelMode.PARALLEL_2D_COL |
||||
) |
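# A sketch of what the splitter is assumed to do (based on the analogous |
# 2.5D _ViT_Split implementation later in this diff): the input batch is |
# chunked along dim 0 across the PARALLEL_2D_COL group, so a global batch of |
# size B becomes B // summa_dim per rank, and the backward pass gathers the |
# gradient back into the full batch. |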
@ -0,0 +1,258 @@
|
||||
import math |
||||
|
||||
import torch |
||||
import torch.distributed as dist |
||||
from torch import Tensor |
||||
from torch.nn import Parameter, init as init |
||||
|
||||
from colossalai.context import seed, ParallelMode |
||||
from colossalai.core import global_context as gpc |
||||
from colossalai.registry import LAYERS |
||||
from colossalai.utils import get_current_device |
||||
from ._operation import Matmul_AB_2D, Add_Bias_2D, _LayerNorm_2D |
||||
from ._utils import get_summa_dim_from_env, assert_summa_initialization |
||||
from .._common_utils import divide, set_tensor_parallel_attribute |
||||
from ..base_layer import ParallelLayer |
||||
|
||||
|
||||
@LAYERS.register_module |
||||
class Linear2D(ParallelLayer): |
||||
""" Linear layer for 2D parallelism |
||||
|
||||
:param in_features: size of each input sample |
||||
:type in_features: int |
||||
:param out_features: size of each output sample |
||||
:type out_features: int |
||||
:param bias: If set to ``False``, the layer will not learn an additive bias, defaults to True |
||||
:type bias: bool, optional |
||||
:param dtype: The dtype of parameters, defaults to None |
||||
:type dtype: torch.dtype, optional |
||||
:param skip_bias_add: If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion, defaults to False |
||||
:type skip_bias_add: bool, optional |
||||
""" |
||||
|
||||
def __init__(self, |
||||
in_features: int, |
||||
out_features: int, |
||||
bias: bool = True, |
||||
dtype=None, |
||||
skip_bias_add: bool = False |
||||
): |
||||
super().__init__() |
||||
|
||||
self.in_features = in_features |
||||
self.out_features = out_features |
||||
self.skip_bias_add = skip_bias_add |
||||
|
||||
# parallel settings |
||||
assert_summa_initialization() |
||||
self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) |
||||
self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) |
||||
self.summa_dim = get_summa_dim_from_env() |
||||
|
||||
# partitioning dimension |
||||
self.input_size_per_partition = divide( |
||||
self.in_features, self.summa_dim) |
||||
self.hidden_size_per_partition = divide( |
||||
self.out_features, self.summa_dim) |
||||
|
||||
# create weight, shape: [k/q, h/q] |
||||
factory_kwargs = {'device': get_current_device(), 'dtype': dtype} |
||||
self.weight = Parameter(torch.empty( |
||||
self.input_size_per_partition, |
||||
self.hidden_size_per_partition, |
||||
**factory_kwargs)) |
||||
|
||||
# create bias, shape: [h/q] |
||||
if bias: |
||||
self.bias = Parameter(torch.empty( |
||||
self.hidden_size_per_partition, |
||||
**factory_kwargs)) |
||||
else: |
||||
self.register_parameter('bias', None) |
||||
|
||||
# initialize parameters |
||||
self.reset_parameters() |
||||
self._set_tensor_parallel_attributes() |
||||
|
||||
def _set_tensor_parallel_attributes(self): |
||||
set_tensor_parallel_attribute(self.weight) |
||||
if self.bias is not None: |
||||
set_tensor_parallel_attribute(self.bias) |
||||
|
||||
def reset_parameters(self) -> None: |
||||
# setting |
||||
fan_in = self.in_features |
||||
a = math.sqrt(5) |
||||
nonlinearity = 'leaky_relu' |
||||
|
||||
# init weight |
||||
std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in) |
||||
bound = math.sqrt(3.0) * std |
||||
with seed(ParallelMode.TENSOR): |
||||
init.uniform_(self.weight, -bound, bound) |
||||
|
||||
# init bias |
||||
if self.bias is not None: |
||||
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 |
||||
with seed(ParallelMode.TENSOR): |
||||
init.uniform_(self.bias, -bound, bound) |
||||
|
||||
def forward(self, x: Tensor) -> Tensor: |
||||
# input: [m/q, n/q, k/q] |
||||
# output: [m/q, n/q, h/q] |
||||
out_shape = x.shape[:-1] + (self.hidden_size_per_partition,) |
||||
|
||||
output = Matmul_AB_2D.apply( |
||||
x, |
||||
self.weight, |
||||
self.summa_dim, |
||||
out_shape, |
||||
self.row_rank, |
||||
self.col_rank, |
||||
ParallelMode.PARALLEL_2D_ROW, |
||||
ParallelMode.PARALLEL_2D_COL, |
||||
self.data_parallel_rank, |
||||
self.pipeline_parallel_rank, |
||||
self.pipeline_parallel_size, |
||||
self.tensor_parallel_size) |
||||
|
||||
if self.bias is not None: |
||||
if self.skip_bias_add: |
||||
bias = Add_Bias_2D.apply( |
||||
None, |
||||
self.bias, |
||||
self.hidden_size_per_partition, |
||||
self.row_rank, |
||||
self.col_rank, |
||||
ParallelMode.PARALLEL_2D_ROW, |
||||
ParallelMode.PARALLEL_2D_COL, |
||||
True, |
||||
self.data_parallel_rank, |
||||
self.pipeline_parallel_rank, |
||||
self.pipeline_parallel_size, |
||||
self.tensor_parallel_size |
||||
) |
||||
return output, bias |
||||
else: |
||||
output = Add_Bias_2D.apply( |
||||
output, |
||||
self.bias, |
||||
self.hidden_size_per_partition, |
||||
self.row_rank, |
||||
self.col_rank, |
||||
ParallelMode.PARALLEL_2D_ROW, |
||||
ParallelMode.PARALLEL_2D_COL, |
||||
False, |
||||
self.data_parallel_rank, |
||||
self.pipeline_parallel_rank, |
||||
self.pipeline_parallel_size, |
||||
self.tensor_parallel_size |
||||
) |
||||
return output |
||||
else: |
||||
return output |
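# A shape sketch for Linear2D (hypothetical values): with q = summa_dim, each |
# rank stores a weight block of shape [in_features / q, out_features / q] and |
# a bias slice of length out_features / q; Matmul_AB_2D combines the q x q |
# blocks with row/column broadcasts so the local output keeps the partitioned |
# layout [..., out_features / q]. For in_features=1024, out_features=4096 and |
# q=2, each rank holds a 512 x 2048 weight block. |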
||||
|
||||
|
||||
@LAYERS.register_module |
||||
class LayerNorm2D(ParallelLayer): |
||||
r"""Layer Normalization for 2D parallelism |
||||
|
||||
:param normalized_shape: input shape from an expected input |
||||
of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]` |
||||
If a single integer is used, it is treated as a singleton list, and this module will |
||||
normalize over the last dimension which is expected to be of that specific size. |
||||
:type normalized_shape: int |
||||
:param eps: a value added to the denominator for numerical stability, defaults to 1e-05 |
||||
:type eps: float, optional |
||||
:param dtype: The dtype of parameters, defaults to None |
||||
:type dtype: torch.dtype, optional |
||||
""" |
||||
|
||||
def __init__(self, |
||||
normalized_shape: int, |
||||
eps: float = 1e-05, |
||||
dtype=None |
||||
): |
||||
super().__init__() |
||||
|
||||
# layer norm config |
||||
self.normalized_shape = normalized_shape |
||||
self.variance_epsilon = eps |
||||
|
||||
# parallel setting |
||||
assert_summa_initialization() |
||||
self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) |
||||
self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) |
||||
self.summa_dim = get_summa_dim_from_env() |
||||
|
||||
# partitioning dimension |
||||
self.partitioned_partition = divide(normalized_shape, self.summa_dim) |
||||
|
||||
# create parameters |
||||
factory_kwargs = {'device': get_current_device(), 'dtype': dtype} |
||||
|
||||
if self.row_rank == 0: |
||||
self.gamma = Parameter(torch.ones( |
||||
self.partitioned_partition, |
||||
**factory_kwargs)) |
||||
self.beta = Parameter(torch.zeros( |
||||
self.partitioned_partition, |
||||
**factory_kwargs)) |
||||
else: |
||||
self.gamma = Parameter(torch.tensor( |
||||
1.0, |
||||
requires_grad=True, |
||||
**factory_kwargs)) |
||||
self.beta = Parameter(torch.tensor( |
||||
1.0, |
||||
requires_grad=True, |
||||
**factory_kwargs)) |
||||
|
||||
self._set_tensor_parallel_attributes() |
||||
|
||||
def _set_tensor_parallel_attributes(self): |
||||
set_tensor_parallel_attribute(self.gamma) |
||||
set_tensor_parallel_attribute(self.beta) |
||||
|
||||
def forward(self, x: Tensor) -> Tensor: |
||||
with torch.no_grad(): |
||||
E_x = torch.sum(x, dim=-1, keepdim=True) # [b/q, s, 1] |
||||
torch.distributed.all_reduce( |
||||
E_x, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW)) |
||||
E_x /= self.normalized_shape |
||||
|
||||
# Var_x in the block below is the sum of input^2 |
||||
Var_x = torch.sum(x * x, dim=-1, keepdim=True) # [b/q, s, 1] |
||||
torch.distributed.all_reduce( |
||||
Var_x, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW)) |
||||
Var_x /= self.normalized_shape |
||||
|
||||
Var_x = Var_x - E_x * E_x # variance of x [b/q, s, 1] |
||||
# this time 1/sqrt(Var_x + epsilon) |
||||
Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon) |
||||
|
||||
output = _LayerNorm_2D.apply(x, E_x, Var_x, self.normalized_shape, |
||||
ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL) |
||||
bias = Add_Bias_2D.apply( |
||||
None, self.beta, self.partitioned_partition, |
||||
self.row_rank, self.col_rank, |
||||
ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, |
||||
True, |
||||
self.data_parallel_rank, |
||||
self.pipeline_parallel_rank, |
||||
self.pipeline_parallel_size, |
||||
self.tensor_parallel_size |
||||
) |
||||
scale = Add_Bias_2D.apply( |
||||
None, self.gamma, self.partitioned_partition, |
||||
self.row_rank, self.col_rank, |
||||
ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, |
||||
True, |
||||
self.data_parallel_rank, |
||||
self.pipeline_parallel_rank, |
||||
self.pipeline_parallel_size, |
||||
self.tensor_parallel_size |
||||
) |
||||
output = torch.addcmul(bias, scale, output) |
||||
return output |
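# The math assembled above, written out for reference (H = normalized_shape): |
#   E[x]   = (1 / H) * sum_h x_h          (partial sums + row-group all-reduce) |
#   Var[x] = (1 / H) * sum_h x_h^2 - E[x]^2 |
#   y      = gamma * (x - E[x]) / sqrt(Var[x] + eps) + beta |
# which is what torch.addcmul(bias, scale, output) produces element-wise. |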
@ -0,0 +1,13 @@
|
||||
from ._operation import Matmul_AB_2p5D, Matmul_ABT_2p5D, Matmul_ATB_2p5D, Sum_2p5D, Add_Bias_2p5D |
||||
from ._transformer import TransformerMLP2p5D, TransformerSelfAttention2p5D, TransformerLayer2p5D |
||||
from ._vit import (ViTMLP2p5D, ViTSelfAttention2p5D, ViTHead2p5D, ViTPatchEmbedding2p5D, ViTTokenFuser2p5D, |
||||
ViTInputSplitter2p5D) |
||||
from .layers import Linear2p5D, LayerNorm2p5D |
||||
|
||||
__all__ = [ |
||||
'Matmul_AB_2p5D', 'Matmul_ABT_2p5D', 'Matmul_ATB_2p5D', 'Sum_2p5D', 'Add_Bias_2p5D', |
||||
'TransformerMLP2p5D', 'TransformerSelfAttention2p5D', 'TransformerLayer2p5D', |
||||
'ViTMLP2p5D', 'ViTSelfAttention2p5D', 'ViTHead2p5D', 'ViTPatchEmbedding2p5D', 'ViTTokenFuser2p5D', |
||||
'ViTInputSplitter2p5D', |
||||
'Linear2p5D', 'LayerNorm2p5D' |
||||
] |
@ -0,0 +1,535 @@
|
||||
from typing import Any, Tuple |
||||
|
||||
import torch |
||||
import torch.distributed as dist |
||||
from torch import Tensor |
||||
|
||||
from colossalai.context.parallel_mode import ParallelMode |
||||
from colossalai.core import global_context as gpc |
||||
from colossalai.utils import get_current_device, empty_cache |
||||
|
||||
|
||||
def get_parallel_group(parallel_mode: ParallelMode): |
||||
return gpc.get_group(parallel_mode) |
||||
|
||||
|
||||
def get_global_rank(): |
||||
return gpc.get_global_rank() |
||||
|
||||
|
||||
def get_parallel_rank(parallel_mode: ParallelMode): |
||||
return gpc.get_local_rank(parallel_mode) |
||||
|
||||
|
||||
class Matmul_AB_2p5D(torch.autograd.Function): |
||||
"""Matrix multiplication for :math:`C = AB` |
||||
""" |
||||
|
||||
@staticmethod |
||||
def forward(ctx: Any, |
||||
A: Tensor, |
||||
B: Tensor, |
||||
tesseract_dim: int, |
||||
tesseract_dep: int, |
||||
out_shape: Tuple[int, ...], |
||||
row_rank: int, |
||||
col_rank: int, |
||||
dep_rank: int, |
||||
row_parallel_mode: ParallelMode, |
||||
col_parallel_mode: ParallelMode, |
||||
dep_parallel_mode: ParallelMode, |
||||
data_parallel_rank: int, |
||||
pipeline_parallel_rank: int, |
||||
pipeline_parallel_size: int, |
||||
tensor_parallel_size: int) -> Tensor: |
||||
# A: [b / dq, s, h / q] -> [(b * s) / dq, h / q] |
||||
# B: [h / dq, s / q] |
||||
# C: [b / dq, s, s / q] -> [(b * s) / dq, s / q] |
||||
|
||||
assert A.shape[-1] == B.shape[-2], \ |
||||
'Invalid shapes: A={}, B={} for AB.'.format(A.shape, B.shape) |
||||
|
||||
empty_cache() |
||||
if ctx: |
||||
ctx.save_for_backward(A, B) |
||||
|
||||
A_shape = A.shape |
||||
A = A.reshape((-1, A_shape[-1])) |
||||
B_shape = B.shape |
||||
B = B.reshape((-1, B_shape[-1])) |
||||
C_shape = (A.shape[0], B.shape[-1]) |
||||
C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device()) |
||||
|
||||
for i in range(tesseract_dim): |
||||
A_temp = A.clone() |
||||
B_temp = B.clone() |
||||
src_a = i + row_rank * tesseract_dim + dep_rank * ( |
||||
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + pipeline_parallel_rank * tensor_parallel_size |
||||
dist.broadcast(A_temp, src=src_a, |
||||
group=get_parallel_group(row_parallel_mode)) |
||||
src_b = col_rank + i * tesseract_dim + dep_rank * ( |
||||
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + pipeline_parallel_rank * tensor_parallel_size |
||||
dist.broadcast(B_temp, src=src_b, |
||||
group=get_parallel_group(col_parallel_mode)) |
||||
torch.addmm(C, A_temp, B_temp, out=C) |
||||
|
||||
out = C.reshape(out_shape) |
||||
|
||||
if ctx: |
||||
ctx.tesseract_dim = tesseract_dim |
||||
ctx.tesseract_dep = tesseract_dep |
||||
ctx.row_rank = row_rank |
||||
ctx.col_rank = col_rank |
||||
ctx.dep_rank = dep_rank |
||||
ctx.row_parallel_mode = row_parallel_mode |
||||
ctx.col_parallel_mode = col_parallel_mode |
||||
ctx.dep_parallel_mode = dep_parallel_mode |
||||
ctx.A_shape = A_shape |
||||
ctx.B_shape = B_shape |
||||
ctx.data_parallel_rank = data_parallel_rank |
||||
ctx.pipeline_parallel_rank = pipeline_parallel_rank |
||||
ctx.pipeline_parallel_size = pipeline_parallel_size |
||||
ctx.tensor_parallel_size = tensor_parallel_size |
||||
|
||||
return out |
||||
|
||||
@staticmethod |
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: |
||||
A, B = ctx.saved_tensors |
||||
A_grad = Matmul_ABT_2p5D.forward( |
||||
None, |
||||
output_grad, B, |
||||
ctx.tesseract_dim, ctx.tesseract_dep, ctx.A_shape, |
||||
ctx.row_rank, ctx.col_rank, ctx.dep_rank, |
||||
ctx.row_parallel_mode, |
||||
ctx.col_parallel_mode, |
||||
ctx.dep_parallel_mode, |
||||
ctx.data_parallel_rank, |
||||
ctx.pipeline_parallel_rank, |
||||
ctx.pipeline_parallel_size, |
||||
ctx.tensor_parallel_size |
||||
) |
||||
B_grad = Matmul_ATB_2p5D.forward( |
||||
None, |
||||
A, output_grad, |
||||
ctx.tesseract_dim, ctx.tesseract_dep, ctx.B_shape, |
||||
ctx.row_rank, ctx.col_rank, ctx.dep_rank, |
||||
ctx.row_parallel_mode, |
||||
ctx.col_parallel_mode, |
||||
ctx.dep_parallel_mode, |
||||
ctx.data_parallel_rank, |
||||
ctx.pipeline_parallel_rank, |
||||
ctx.pipeline_parallel_size, |
||||
ctx.tensor_parallel_size |
||||
) |
||||
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None |
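# The backward above relies on the standard matmul gradient identities for |
# C = A @ B, namely dA = dC @ B^T and dB = A^T @ dC, which is why it reuses |
# Matmul_ABT_2p5D and Matmul_ATB_2p5D with the same tesseract/rank |
# bookkeeping as the forward pass. |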
||||
|
||||
|
||||
class Matmul_ABT_2p5D(torch.autograd.Function): |
||||
"""Matrix multiplication for :math:`C = AB^T` |
||||
""" |
||||
|
||||
@staticmethod |
||||
def forward(ctx: Any, |
||||
A: Tensor, |
||||
B: Tensor, |
||||
tesseract_dim: int, |
||||
tesseract_dep: int, |
||||
out_shape: Tuple[int, ...], |
||||
row_rank: int, |
||||
col_rank: int, |
||||
dep_rank: int, |
||||
row_parallel_mode: ParallelMode, |
||||
col_parallel_mode: ParallelMode, |
||||
dep_parallel_mode: ParallelMode, |
||||
data_parallel_rank: int, |
||||
pipeline_parallel_rank: int, |
||||
pipeline_parallel_size: int, |
||||
tensor_parallel_size: int |
||||
) -> Tensor: |
||||
|
||||
assert A.shape[-1] == B.shape[-1], \ |
||||
'Invalid shapes: A={}, B={} for ABT.'.format(A.shape, B.shape) |
||||
|
||||
empty_cache() |
||||
if ctx: |
||||
ctx.save_for_backward(A, B) |
||||
|
||||
A_shape = A.shape |
||||
A = A.reshape((-1, A_shape[-1])) |
||||
B_shape = B.shape |
||||
B = B.reshape((-1, B_shape[-1])) |
||||
C_shape = (A.shape[0], B.shape[0]) |
||||
C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device()) |
||||
|
||||
for i in range(tesseract_dim): |
||||
B_temp = B.clone() |
||||
src_b = col_rank + i * tesseract_dim + dep_rank * ( |
||||
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ |
||||
pipeline_parallel_rank * tensor_parallel_size |
||||
dist.broadcast(B_temp, src=src_b, group=gpc.get_group(col_parallel_mode)) |
||||
C_temp = torch.matmul(A, B_temp.transpose(0, 1)) |
||||
src_c = i + row_rank * tesseract_dim + dep_rank * ( |
||||
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ |
||||
pipeline_parallel_rank * tensor_parallel_size |
||||
dist.reduce(C_temp, dst=src_c, group=gpc.get_group(row_parallel_mode)) |
||||
if i == col_rank: |
||||
C = C_temp.clone() |
||||
|
||||
out = C.reshape(out_shape) |
||||
|
||||
if ctx: |
||||
ctx.tesseract_dim = tesseract_dim |
||||
ctx.tesseract_dep = tesseract_dep |
||||
ctx.row_rank = row_rank |
||||
ctx.col_rank = col_rank |
||||
ctx.dep_rank = dep_rank |
||||
ctx.row_parallel_mode = row_parallel_mode |
||||
ctx.col_parallel_mode = col_parallel_mode |
||||
ctx.dep_parallel_mode = dep_parallel_mode |
||||
ctx.A_shape = A_shape |
||||
ctx.B_shape = B_shape |
||||
ctx.data_parallel_rank = data_parallel_rank |
||||
ctx.pipeline_parallel_rank = pipeline_parallel_rank |
||||
ctx.pipeline_parallel_size = pipeline_parallel_size |
||||
ctx.tensor_parallel_size = tensor_parallel_size |
||||
|
||||
return out |
||||
|
||||
@staticmethod |
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: |
||||
A, B = ctx.saved_tensors |
||||
A_grad = Matmul_AB_2p5D.forward( |
||||
None, |
||||
output_grad, B, |
||||
ctx.tesseract_dim, ctx.tesseract_dep, ctx.A_shape, |
||||
ctx.row_rank, ctx.col_rank, ctx.dep_rank, |
||||
ctx.row_parallel_mode, |
||||
ctx.col_parallel_mode, |
||||
ctx.dep_parallel_mode, |
||||
ctx.data_parallel_rank, |
||||
ctx.pipeline_parallel_rank, |
||||
ctx.pipeline_parallel_size, |
||||
ctx.tensor_parallel_size |
||||
) |
||||
B_grad = Matmul_ATB_2p5D.forward( |
||||
None, |
||||
output_grad, A, |
||||
ctx.tesseract_dim, ctx.tesseract_dep, ctx.B_shape, |
||||
ctx.row_rank, ctx.col_rank, ctx.dep_rank, |
||||
ctx.row_parallel_mode, |
||||
ctx.col_parallel_mode, |
||||
ctx.dep_parallel_mode, |
||||
ctx.data_parallel_rank, |
||||
ctx.pipeline_parallel_rank, |
||||
ctx.pipeline_parallel_size, |
||||
ctx.tensor_parallel_size |
||||
) |
||||
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None |
||||
|
||||
|
||||
class Matmul_ATB_2p5D(torch.autograd.Function): |
||||
"""Matrix multiplication for :math:`C = A^TB` |
||||
""" |
||||
|
||||
@staticmethod |
||||
def forward(ctx: Any, |
||||
A: Tensor, |
||||
B: Tensor, |
||||
tesseract_dim: int, |
||||
tesseract_dep: int, |
||||
out_shape: Tuple[int, ...], |
||||
row_rank: int, |
||||
col_rank: int, |
||||
dep_rank: int, |
||||
row_parallel_mode: ParallelMode, |
||||
col_parallel_mode: ParallelMode, |
||||
dep_parallel_mode: ParallelMode, |
||||
data_parallel_rank: int, |
||||
pipeline_parallel_rank: int, |
||||
pipeline_parallel_size: int, |
||||
tensor_parallel_size: int): |
||||
|
||||
assert A.shape[-2] == B.shape[-2], \ |
||||
'Invalid shapes: A={}, B={} for ATB.'.format(A.shape, B.shape) |
||||
|
||||
empty_cache() |
||||
if ctx: |
||||
ctx.save_for_backward(A, B) |
||||
|
||||
A_shape = A.shape |
||||
A = A.reshape((-1, A_shape[-1])) |
||||
B_shape = B.shape |
||||
B = B.reshape((-1, B_shape[-1])) |
||||
C_shape = (A.shape[-1], B.shape[-1]) |
||||
C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device()) |
||||
|
||||
for i in range(tesseract_dim): |
||||
A_temp = A.clone() |
||||
src_a = i + row_rank * tesseract_dim + dep_rank * ( |
||||
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ |
||||
pipeline_parallel_rank * tensor_parallel_size |
||||
dist.broadcast(A_temp, src=src_a, |
||||
group=get_parallel_group(row_parallel_mode)) |
||||
C_temp = torch.matmul(A_temp.transpose(0, 1), B) |
||||
src_c = col_rank + i * tesseract_dim + dep_rank * ( |
||||
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ |
||||
pipeline_parallel_rank * tensor_parallel_size |
||||
dist.reduce(C_temp, dst=src_c, |
||||
group=get_parallel_group(col_parallel_mode)) |
||||
if i == row_rank: |
||||
C = C_temp.clone() |
||||
|
||||
out = C.reshape(out_shape) |
||||
|
||||
if ctx: |
||||
ctx.tesseract_dim = tesseract_dim |
||||
ctx.tesseract_dep = tesseract_dep |
||||
ctx.row_rank = row_rank |
||||
ctx.col_rank = col_rank |
||||
ctx.dep_rank = dep_rank |
||||
ctx.row_parallel_mode = row_parallel_mode |
||||
ctx.col_parallel_mode = col_parallel_mode |
||||
ctx.dep_parallel_mode = dep_parallel_mode |
||||
ctx.A_shape = A_shape |
||||
ctx.B_shape = B_shape |
||||
ctx.data_parallel_rank = data_parallel_rank |
||||
ctx.pipeline_parallel_rank = pipeline_parallel_rank |
||||
ctx.pipeline_parallel_size = pipeline_parallel_size |
||||
ctx.tensor_parallel_size = tensor_parallel_size |
||||
|
||||
return out |
||||
|
||||
@staticmethod |
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: |
||||
A, B = ctx.saved_tensors |
||||
A_grad = Matmul_ABT_2p5D.forward( |
||||
None, |
||||
B, output_grad, |
||||
ctx.tesseract_dim, ctx.tesseract_dep, ctx.A_shape, |
||||
ctx.row_rank, ctx.col_rank, ctx.dep_rank, |
||||
ctx.row_parallel_mode, |
||||
ctx.col_parallel_mode, |
||||
ctx.dep_parallel_mode, |
||||
ctx.data_parallel_rank, |
||||
ctx.pipeline_parallel_rank, |
||||
ctx.pipeline_parallel_size, |
||||
ctx.tensor_parallel_size |
||||
) |
||||
B_grad = Matmul_AB_2p5D.forward( |
||||
None, |
||||
A, output_grad, |
||||
ctx.tesseract_dim, ctx.tesseract_dep, ctx.B_shape, |
||||
ctx.row_rank, ctx.col_rank, ctx.dep_rank, |
||||
ctx.row_parallel_mode, |
||||
ctx.col_parallel_mode, |
||||
ctx.dep_parallel_mode, |
||||
ctx.data_parallel_rank, |
||||
ctx.pipeline_parallel_rank, |
||||
ctx.pipeline_parallel_size, |
||||
ctx.tensor_parallel_size |
||||
) |
||||
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None |
||||
|
||||
|
||||
class Add_Bias_2p5D(torch.autograd.Function): |
||||
"""Matrix add bias: :math:`C = A + b` |
||||
""" |
||||
|
||||
@staticmethod |
||||
def forward(ctx: Any, |
||||
input: Tensor, |
||||
bias: Tensor, |
||||
output_size_per_partition: int, |
||||
tesseract_dim: int, |
||||
tesseract_dep: int, |
||||
row_rank: int, |
||||
col_rank: int, |
||||
dep_rank: int, |
||||
row_parallel_mode: ParallelMode, |
||||
col_parallel_mode: ParallelMode, |
||||
dep_parallel_mode: ParallelMode, |
||||
skip_bias_add: bool, |
||||
data_parallel_rank: int, |
||||
pipeline_parallel_rank: int, |
||||
pipeline_parallel_size: int, |
||||
tensor_parallel_size: int |
||||
) -> Tensor: |
||||
if row_rank == 0: |
||||
bias_temp = bias.clone() |
||||
else: |
||||
bias_temp = torch.zeros( |
||||
output_size_per_partition, |
||||
dtype=bias.dtype, |
||||
device=get_current_device()) |
||||
src_rank = col_rank + dep_rank * ( |
||||
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ |
||||
pipeline_parallel_rank * tensor_parallel_size |
||||
dist.broadcast(bias_temp, src=src_rank, group=get_parallel_group(col_parallel_mode)) |
||||
|
||||
ctx.row_rank = row_rank |
||||
ctx.col_rank = col_rank |
||||
ctx.dep_rank = dep_rank |
||||
ctx.tesseract_dim = tesseract_dim |
||||
ctx.tesseract_dep = tesseract_dep |
||||
ctx.row_parallel_mode = row_parallel_mode |
||||
ctx.col_parallel_mode = col_parallel_mode |
||||
ctx.dep_parallel_mode = dep_parallel_mode |
||||
ctx.bias = skip_bias_add |
||||
ctx.data_parallel_rank = data_parallel_rank |
||||
ctx.pipeline_parallel_rank = pipeline_parallel_rank |
||||
ctx.pipeline_parallel_size = pipeline_parallel_size |
||||
ctx.tensor_parallel_size = tensor_parallel_size |
||||
|
||||
if skip_bias_add: |
||||
return bias_temp |
||||
else: |
||||
output = input + bias_temp |
||||
return output |
||||
|
||||
@staticmethod |
||||
def backward(ctx, output_grad): |
||||
row_rank = ctx.row_rank |
||||
col_rank = ctx.col_rank |
||||
dep_rank = ctx.dep_rank |
||||
tesseract_dim = ctx.tesseract_dim |
||||
tesseract_dep = ctx.tesseract_dep |
||||
row_parallel_mode = ctx.row_parallel_mode |
||||
col_parallel_mode = ctx.col_parallel_mode |
||||
dep_parallel_mode = ctx.dep_parallel_mode |
||||
data_parallel_rank = ctx.data_parallel_rank |
||||
pipeline_parallel_rank = ctx.pipeline_parallel_rank |
||||
pipeline_parallel_size = ctx.pipeline_parallel_size |
||||
tensor_parallel_size = ctx.tensor_parallel_size |
||||
|
||||
if ctx.bias: |
||||
dst_rank = col_rank + dep_rank * ( |
||||
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ |
||||
pipeline_parallel_rank * tensor_parallel_size |
||||
dist.reduce(output_grad, dst=dst_rank, group=get_parallel_group(col_parallel_mode)) |
||||
if row_rank == 0: |
||||
return None, output_grad, None, None, None, None, None, None, None, None, None, None, None, None, None, None |
||||
else: |
||||
grad_tmp = torch.zeros_like(output_grad) |
||||
return None, grad_tmp, None, None, None, None, None, None, None, None, None, None, None, None, None, None |
||||
else: |
||||
reduce_dim = tuple(range(output_grad.ndim - 1)) |
||||
reduce = torch.sum(output_grad, dim=reduce_dim) |
||||
dst_rank = col_rank + dep_rank * ( |
||||
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ |
||||
pipeline_parallel_rank * tensor_parallel_size |
||||
dist.reduce(reduce, dst=dst_rank, group=get_parallel_group(col_parallel_mode)) |
||||
if row_rank == 0: |
||||
return output_grad, reduce, None, None, None, None, None, None, None, None, None, None, None, None, None, None |
||||
else: |
||||
reduce_tmp = torch.zeros_like(reduce) |
||||
return output_grad, reduce_tmp, None, None, None, None, None, None, None, None, None, None, None, None, None, None |
||||
|
||||
|
||||
class _LayerNorm_2p5D(torch.autograd.Function): |
||||
@staticmethod |
||||
def forward(ctx: Any, |
||||
input: Tensor, |
||||
E_x: Tensor, |
||||
Var_x: Tensor, |
||||
hidden_size: int, |
||||
row_parallel_mode: ParallelMode, |
||||
col_parallel_mode: ParallelMode, |
||||
dep_parallel_mode: ParallelMode) -> Tensor: |
||||
input = input - E_x |
||||
# in here, input = x - E[x], Var_x = 1 / sqrt(Var[x] + eps) |
||||
ctx.hidden_size = hidden_size |
||||
output = input * Var_x |
||||
ctx.save_for_backward(output, Var_x) |
||||
ctx.row_parallel_mode = row_parallel_mode |
||||
ctx.col_parallel_mode = col_parallel_mode |
||||
ctx.dep_parallel_mode = dep_parallel_mode |
||||
return output |
||||
|
||||
@staticmethod |
||||
def backward(ctx, output_grad): |
||||
row_parallel_mode = ctx.row_parallel_mode |
||||
col_parallel_mode = ctx.col_parallel_mode |
||||
dep_parallel_mode = ctx.dep_parallel_mode |
||||
x, Var_x = ctx.saved_tensors |
||||
# in here, Var_x = 1 / sqrt(Var[x] + eps), x = (x - E[x]) * Var_x |
||||
with torch.no_grad(): |
||||
output_grad_sum = torch.sum(output_grad, dim=-1, keepdim=True) |
||||
torch.distributed.all_reduce( |
||||
output_grad_sum, group=get_parallel_group(row_parallel_mode)) |
||||
output_grad_sum /= ctx.hidden_size |
||||
|
||||
output_grad_mul_x_sum = torch.sum( |
||||
output_grad * x, dim=-1, keepdim=True) |
||||
torch.distributed.all_reduce( |
||||
output_grad_mul_x_sum, group=get_parallel_group(row_parallel_mode)) |
||||
output_grad_mul_x_sum /= ctx.hidden_size |
||||
|
||||
input_grad = output_grad.clone() |
||||
input_grad -= x * output_grad_mul_x_sum |
||||
input_grad -= output_grad_sum |
||||
input_grad *= Var_x |
||||
|
||||
return input_grad, None, None, None, None, None, None |
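# For reference, the gradient computed above is the usual layer-norm backward |
# with x_hat = (x - E[x]) / sqrt(Var[x] + eps) saved from the forward pass: |
#   dx = (dy - mean_h(dy) - x_hat * mean_h(dy * x_hat)) / sqrt(Var[x] + eps) |
# where both means over the hidden dimension are formed from local partial |
# sums followed by an all-reduce over the row-parallel group. |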
||||
|
||||
|
||||
class Sum_2p5D(torch.autograd.Function): |
||||
"""Compute the sum of input tensors |
||||
""" |
||||
|
||||
@staticmethod |
||||
def forward(ctx, |
||||
inputs, |
||||
dim, |
||||
tesseract_dim, |
||||
row_parallel_mode, |
||||
keepdim=False): |
||||
# input: [b/q, s, h/q] |
||||
empty_cache() |
||||
ctx.save_for_backward(inputs) |
||||
# sum: [b/q, s] |
||||
out = torch.sum(inputs, dim=dim, keepdim=keepdim) |
||||
torch.distributed.all_reduce( |
||||
out, group=gpc.get_group(row_parallel_mode)) |
||||
return out |
||||
|
||||
@staticmethod |
||||
def backward(ctx, output_grad): |
||||
with torch.no_grad(): |
||||
inputs, = ctx.saved_tensors |
||||
input_grad = torch.ones(inputs.shape, dtype=output_grad.dtype, device=output_grad.device) |
||||
return input_grad, None, None, None, None |
||||
|
||||
|
||||
class _ViT_Split_2p5D(torch.autograd.Function): |
||||
@staticmethod |
||||
def forward(ctx, inputs, batch_size, |
||||
tesseract_dim, tesseract_dep, |
||||
xz_parallel_mode): |
||||
# inputs: [b, s, h/q] |
||||
# output: [b/dq, s, h/q] |
||||
empty_cache() |
||||
|
||||
ctx.batch_size = batch_size |
||||
ctx.tesseract_dim = tesseract_dim |
||||
ctx.tesseract_dep = tesseract_dep |
||||
ctx.xz_parallel_mode = xz_parallel_mode |
||||
xz_rank = gpc.get_local_rank(xz_parallel_mode) |
||||
output = torch.chunk(inputs, tesseract_dep * |
||||
tesseract_dim, dim=0)[xz_rank] |
||||
output = output.clone() |
||||
return output |
||||
|
||||
@staticmethod |
||||
def backward(ctx, output_grad): |
||||
# output_grad: [b/dq, s, h/q] |
||||
# grads: [b, s, h/q] |
||||
grads_shape = (ctx.batch_size,) + output_grad.shape[1:] |
||||
grads = torch.empty(grads_shape, |
||||
dtype=output_grad.dtype, |
||||
device=get_current_device()) |
||||
dist.all_gather(list(grads.chunk(ctx.tesseract_dim * ctx.tesseract_dep, dim=0)), |
||||
output_grad.contiguous(), |
||||
group=get_parallel_group(ctx.xz_parallel_mode)) |
||||
return grads, None, None, None, None |
@ -0,0 +1,206 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
import math |
||||
|
||||
import torch |
||||
from torch import nn as nn, Tensor |
||||
|
||||
from colossalai.nn.layer._common_utils import divide |
||||
from colossalai.registry import LAYERS |
||||
from ._utils import assert_tesseract_initialization, \ |
||||
get_tesseract_dim_dep_from_env |
||||
from .layers import Linear2p5D, LayerNorm2p5D |
||||
from .._common_utils import ACT2FN |
||||
|
||||
|
||||
@LAYERS.register_module |
||||
class TransformerMLP2p5D(nn.Module): |
||||
""" |
||||
The MLP takes an input with hidden size h, projects it to mlp_ratio * h, |
applies a nonlinear transformation, and projects the state back to hidden |
size h. Dropout is applied at the end. |
||||
|
||||
:param in_features: the size of input tensor |
||||
:type in_features: int |
||||
:param mlp_ratio: hidden size of MLP divided by embedding dim |
:type mlp_ratio: int |
||||
:param act_func: activation function, defaults to 'gelu' |
||||
:type act_func: str, optional |
||||
:param dropout_prob: dropout probability, defaults to 0. |
||||
:type dropout_prob: float, optional |
||||
:param dtype: dtype of parameters, defaults to None |
||||
:type dtype: torch.dtype, optional |
||||
""" |
||||
|
||||
def __init__(self, |
||||
in_features: int, |
||||
mlp_ratio: int, |
||||
act_func: str = 'gelu', |
||||
dropout_prob: float = 0., |
||||
dtype=None, |
||||
): |
||||
super().__init__() |
||||
assert_tesseract_initialization() |
||||
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env() |
||||
self.in_features = in_features |
||||
|
||||
# Project to h * mlp_ratio. |
||||
self.dense_1 = Linear2p5D( |
||||
in_features, |
||||
mlp_ratio * in_features, |
||||
dtype=dtype |
||||
) |
||||
|
||||
assert act_func in ACT2FN.keys(), f'Invalid value for argument act_func, ' \ |
||||
f'activation function can only be {list(ACT2FN.keys())}' |
||||
self.activation_func = ACT2FN[act_func] |
||||
|
||||
# Project back to h. |
||||
self.dense_2 = Linear2p5D( |
||||
mlp_ratio * in_features, |
||||
in_features, |
||||
dtype=dtype |
||||
) |
||||
self.dropout = nn.Dropout(dropout_prob) |
||||
self.layernorm = LayerNorm2p5D(in_features, dtype=dtype) |
||||
|
||||
def forward(self, x: Tensor) -> Tensor: |
||||
intermediate_output = self.dense_1(x) |
||||
intermediate_output = self.activation_func(intermediate_output) |
||||
output = self.dense_2(intermediate_output) |
||||
output = self.dropout(output) |
||||
output = self.layernorm(x + output) |
||||
return output |
||||
|
||||
|
||||
@LAYERS.register_module |
||||
class TransformerSelfAttention2p5D(nn.Module): |
||||
"""Self attention layer for 2.5D parallel Transformer |
||||
|
||||
:param hidden_size: hidden size |
||||
:type hidden_size: int |
||||
:param num_attention_heads: number of attention heads |
||||
:type num_attention_heads: int |
||||
:param attention_dropout_prob: dropout probability for attention layer |
||||
:type attention_dropout_prob: float |
||||
:param hidden_dropout_prob: dropout probability for hidden layer |
||||
:type hidden_dropout_prob: float |
||||
:param dtype: dtype of parameters, defaults to None |
||||
:type dtype: torch.dtype, optional |
||||
""" |
||||
|
||||
def __init__(self, |
||||
hidden_size, |
||||
num_attention_heads, |
||||
attention_dropout_prob, |
||||
hidden_dropout_prob, |
||||
dtype=None, |
||||
): |
||||
super().__init__() |
||||
|
||||
assert_tesseract_initialization() |
||||
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env() |
||||
self.hidden_size = hidden_size |
||||
self.num_attention_heads = divide( |
||||
num_attention_heads, self.tesseract_dim) # * |
||||
self.attention_head_size = divide(hidden_size, num_attention_heads) |
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size |
||||
|
||||
self.query_key_value = Linear2p5D( |
||||
hidden_size, |
||||
3 * hidden_size, |
||||
dtype=dtype, |
||||
) |
||||
self.attention_dropout = nn.Dropout(attention_dropout_prob) |
||||
self.dense = Linear2p5D( |
||||
hidden_size, |
||||
hidden_size, |
||||
dtype=dtype, |
||||
) |
||||
self.dropout = nn.Dropout(hidden_dropout_prob) |
||||
self.layernorm = LayerNorm2p5D( |
||||
hidden_size, |
||||
dtype=dtype) |
||||
|
||||
def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor: |
||||
query_key_value = self.query_key_value(hidden_states) |
||||
new_qkv_shape = query_key_value.shape[:-1] + \ |
||||
(self.num_attention_heads, 3 * self.attention_head_size) |
||||
query_key_value = query_key_value.view(new_qkv_shape) |
||||
query_key_value = query_key_value.permute((0, 2, 1, 3)) |
||||
query_layer, key_layer, value_layer = torch.chunk( |
||||
query_key_value, 3, dim=-1) |
||||
|
||||
attention_scores = torch.matmul( |
||||
query_layer, key_layer.transpose(-1, -2)) |
||||
attention_scores = attention_scores / \ |
||||
math.sqrt(self.attention_head_size) |
||||
attention_scores = attention_scores + attention_mask |
||||
attention_probs = nn.Softmax(dim=-1)(attention_scores) |
||||
attention_probs = self.attention_dropout(attention_probs) |
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer) |
||||
context_layer = context_layer.permute((0, 2, 1, 3)).contiguous() |
||||
new_context_layer_shape = context_layer.size()[ |
||||
:-2] + (self.all_head_size,) |
||||
context_layer = context_layer.view(*new_context_layer_shape) |
||||
|
||||
output = self.dense(context_layer) |
||||
output = self.dropout(output) |
||||
attention_output = self.layernorm(hidden_states + output) |
||||
|
||||
return attention_output |
||||
|
||||
|
||||
@LAYERS.register_module |
||||
class TransformerLayer2p5D(nn.Module): |
||||
"""Transformer layer which contains a self-attention layer and a MLP layer |
||||
|
||||
:param hidden_size: hidden size |
||||
:type hidden_size: int |
||||
:param num_attention_heads: number of attention heads |
||||
:type num_attention_heads: int |
||||
:param act_func: activation function, defaults to 'gelu' |
||||
:type act_func: str, optional |
||||
:param mlp_ratio: hidden size of MLP divided by embedding dim, defaults to 4.0 |
||||
:type mlp_ratio: float, optional |
||||
:param attention_dropout_prob: dropout probability for attention layer, defaults to 0. |
||||
:type attention_dropout_prob: float, optional |
||||
:param hidden_dropout_prob: dropout probability for hidden layers, defaults to 0. |
||||
:type hidden_dropout_prob: float, optional |
||||
:param dtype: dtype of parameters, defaults to None |
||||
:type dtype: torch.dtype, optional |
||||
""" |
||||
|
||||
def __init__(self, |
||||
hidden_size, |
||||
num_attention_heads, |
||||
act_func='gelu', |
||||
mlp_ratio=4, |
||||
attention_dropout_prob: float = 0., |
||||
hidden_dropout_prob: float = 0., |
||||
dtype=None, |
||||
): |
||||
super().__init__() |
||||
|
||||
self.attention = TransformerSelfAttention2p5D( |
||||
hidden_size=hidden_size, |
||||
num_attention_heads=num_attention_heads, |
||||
attention_dropout_prob=attention_dropout_prob, |
||||
hidden_dropout_prob=hidden_dropout_prob, |
||||
dtype=dtype, |
||||
) |
||||
self.mlp = TransformerMLP2p5D( |
||||
in_features=hidden_size, |
||||
dropout_prob=hidden_dropout_prob, |
||||
act_func=act_func, |
||||
mlp_ratio=mlp_ratio, |
||||
dtype=dtype, |
||||
) |
||||
|
||||
def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor: |
||||
attention_output = self.attention(hidden_states, attention_mask) |
||||
output = self.mlp(attention_output) |
||||
return output |
@ -0,0 +1,25 @@
|
||||
import os |
||||
|
||||
from colossalai.context.parallel_mode import ParallelMode |
||||
from colossalai.core import global_context as gpc |
||||
|
||||
|
||||
def get_tesseract_dim_dep_from_env(): |
||||
try: |
||||
tesseract_dim = int(os.environ['TESSERACT_DIM']) |
||||
tesseract_dep = int(os.environ['TESSERACT_DEP']) |
||||
assert tesseract_dim > 0, 'TESSERACT_DIM must be larger than zero' |
||||
assert tesseract_dep > 0, 'TESSERACT_DEP must be larger than zero' |
||||
return tesseract_dim, tesseract_dep |
||||
|
||||
except KeyError as e: |
||||
raise EnvironmentError('TESSERACT_DIM or TESSERACT_DEP is not found in the current environment, ' |
||||
'please make sure that you have used the correct process group initializer') from e |
||||
|
||||
|
||||
def assert_tesseract_initialization(): |
||||
assert gpc.is_initialized(ParallelMode.PARALLEL_2P5D_COL) and \ |
||||
gpc.is_initialized(ParallelMode.PARALLEL_2P5D_ROW) and \ |
||||
gpc.is_initialized(ParallelMode.PARALLEL_2P5D_DEP) and \ |
||||
gpc.is_initialized(ParallelMode.PARALLEL_2P5D_XZ), \ |
||||
'All of PARALLEL_2P5D_COL, PARALLEL_2P5D_ROW, PARALLEL_2P5D_DEP and PARALLEL_2P5D_XZ must be initialized by the process group initializer' |
@ -0,0 +1,351 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
import math |
||||
|
||||
import torch |
||||
from torch import nn as nn, Tensor, distributed as dist |
||||
|
||||
from colossalai.context.parallel_mode import ParallelMode |
||||
from colossalai.core import global_context as gpc |
||||
from colossalai.nn.layer.vanilla_vision_transformer.layers import to_2tuple |
||||
from colossalai.registry import LAYERS |
||||
from colossalai.utils import get_current_device |
||||
from ._operation import _ViT_Split_2p5D |
||||
from ._utils import assert_tesseract_initialization, \ |
||||
get_tesseract_dim_dep_from_env |
||||
from .layers import Linear2p5D |
||||
from .._common_utils import ACT2FN, divide, CheckpointModule |
||||
from .._common_utils import set_tensor_parallel_attribute |
||||
|
||||
|
||||
@LAYERS.register_module |
||||
class ViTMLP2p5D(CheckpointModule): |
||||
"""MLP layer for 2.5D parallel Vision Transformer |
||||
|
||||
:param in_features: size of each input sample |
||||
:type in_features: int |
||||
:param mlp_ratio: hidden size of MLP divided by embedding dim |
||||
:type mlp_ratio: int |
||||
:param act_func: activation function, defaults to 'gelu' |
||||
:type act_func: str, optional |
||||
:param dropout_prob: dropout probability, defaults to 0. |
||||
:type dropout_prob: float, optional |
||||
:param dtype: The dtype of parameters, defaults to None |
||||
:type dtype: torch.dtype, optional |
||||
:param checkpoint: If set to `True`, activation checkpoint is used, defaults to `False` |
||||
:type checkpoint: bool, optional |
||||
""" |
||||
|
||||
def __init__(self, |
||||
in_features: int, |
||||
mlp_ratio: int, |
||||
act_func: str = 'gelu', |
||||
dropout_prob: float = 0., |
||||
dtype=None, |
||||
checkpoint: bool = False |
||||
): |
||||
super().__init__(checkpoint=checkpoint) |
||||
|
||||
assert_tesseract_initialization() |
||||
self.in_features = in_features |
||||
self.mlp_ratio = mlp_ratio |
||||
|
||||
# Project to mlp_ratio * h. |
||||
self.dense_1 = Linear2p5D( |
||||
self.in_features, |
||||
self.mlp_ratio * self.in_features, |
||||
dtype=dtype, |
||||
) |
||||
|
||||
self.act = ACT2FN[act_func] |
||||
|
||||
# Project back to h. |
||||
self.dense_2 = Linear2p5D( |
||||
self.mlp_ratio * self.in_features, |
||||
self.in_features, |
||||
dtype=dtype, |
||||
) |
||||
self.dropout = nn.Dropout(dropout_prob) |
||||
|
||||
def _forward(self, hidden_states: Tensor) -> Tensor: |
||||
intermediate_output = self.dense_1(hidden_states) |
||||
intermediate_output = self.act(intermediate_output) |
||||
intermediate_output = self.dropout(intermediate_output) |
||||
output = self.dense_2(intermediate_output) |
||||
output = self.dropout(output) |
||||
return output |
||||
|
||||
|
||||
@LAYERS.register_module |
||||
class ViTSelfAttention2p5D(CheckpointModule): |
||||
"""Self-attention layer for 2.5D parallel Vision Transformer |
||||
|
||||
:param hidden_size: hidden size |
||||
:type hidden_size: int |
||||
:param num_attention_heads: number of attention heads |
||||
:type num_attention_heads: int |
||||
:param attention_dropout_prob: dropout probability for attention layers |
||||
:type attention_dropout_prob: float |
||||
:param hidden_dropout_prob: dropout probability for hidden layers |
||||
:type hidden_dropout_prob: float |
||||
:param dtype: dtype of parameters, defaults to None |
||||
:type dtype: torch.dtype, optional |
||||
:param checkpoint: If set to `True`, activation checkpoint is used, defaults to `False` |
||||
:type checkpoint: bool, optional |
||||
""" |
||||
|
||||
def __init__(self, |
||||
hidden_size, |
||||
num_attention_heads, |
||||
attention_dropout_prob, |
||||
hidden_dropout_prob, |
||||
dtype=None, |
||||
checkpoint: bool = False |
||||
): |
||||
super().__init__(checkpoint=checkpoint) |
||||
|
||||
assert_tesseract_initialization() |
||||
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env() |
||||
self.hidden_size = hidden_size |
||||
self.num_attention_heads = divide( |
||||
num_attention_heads, self.tesseract_dim) # * |
||||
self.attention_head_size = divide(hidden_size, num_attention_heads) |
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size |
||||
|
||||
self.query_key_value = Linear2p5D( |
||||
hidden_size, |
||||
3 * hidden_size, |
||||
dtype=dtype, |
||||
) |
||||
self.attention_dropout = nn.Dropout(attention_dropout_prob) |
||||
self.dense = Linear2p5D( |
||||
hidden_size, |
||||
hidden_size, |
||||
dtype=dtype, |
||||
) |
||||
self.dropout = nn.Dropout(hidden_dropout_prob) |
||||
|
||||
def _forward(self, hidden_states: Tensor) -> Tensor: |
||||
query_key_value = self.query_key_value(hidden_states) |
||||
new_qkv_shape = query_key_value.shape[:-1] + \ |
||||
(self.num_attention_heads, 3 * self.attention_head_size) |
||||
query_key_value = query_key_value.view(new_qkv_shape) |
||||
query_key_value = query_key_value.permute((0, 2, 1, 3)) |
||||
query_layer, key_layer, value_layer = torch.chunk( |
||||
query_key_value, 3, dim=-1) |
||||
|
||||
attention_scores = torch.matmul( |
||||
query_layer, key_layer.transpose(-1, -2)) |
||||
attention_scores = attention_scores / \ |
||||
math.sqrt(self.attention_head_size) |
||||
|
||||
attention_probs = nn.Softmax(dim=-1)(attention_scores) |
||||
attention_probs = self.attention_dropout(attention_probs) |
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer) |
||||
context_layer = context_layer.transpose(1, 2) |
||||
new_context_layer_shape = context_layer.size()[ |
||||
:-2] + (self.all_head_size,) |
||||
context_layer = context_layer.reshape(new_context_layer_shape) |
||||
|
||||
output = self.dense(context_layer) |
||||
output = self.dropout(output) |
||||
return output |
||||
|
||||
|
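# Worked example of the head partitioning above (hypothetical sizes): with
# hidden_size=768, num_attention_heads=12 and tesseract_dim=2, each rank keeps
# num_attention_heads/tesseract_dim = 6 heads of size
# hidden_size/num_attention_heads = 64, i.e. all_head_size = 6 * 64 = 384,
# which matches the per-rank output width of the Linear2p5D-sharded
# query_key_value projection (3 * hidden_size split across the tesseract dim).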
||||
@LAYERS.register_module |
||||
class ViTHead2p5D(nn.Module): |
||||
"""Output layer for 2.5D parallel Vision Transformer |
||||
|
||||
:param hidden_size: hidden size |
||||
:type hidden_size: int |
||||
:param num_classes: number of classes |
||||
:type num_classes: int |
||||
:param dtype: dtype of parameters, defaults to None |
||||
:type dtype: torch.dtype, optional |
||||
""" |
||||
|
||||
def __init__(self, |
||||
hidden_size, |
||||
num_classes, |
||||
dtype=None, |
||||
): |
||||
super().__init__() |
||||
assert_tesseract_initialization() |
||||
self.linear = Linear2p5D( |
||||
hidden_size, |
||||
num_classes, |
||||
dtype=dtype, |
||||
) |
||||
|
||||
def forward(self, x: Tensor) -> Tensor: |
||||
x = x[:, 0] |
||||
x = self.linear(x) |
||||
return x |
||||
|
||||
|
||||
@LAYERS.register_module |
||||
class ViTPatchEmbedding2p5D(nn.Module): |
||||
""" 2.5D Image to Patch Embedding |
||||
|
||||
:param img_size: image size
||||
:type img_size: int |
||||
:param patch_size: patch size |
||||
:type patch_size: int |
||||
:param embed_dim: dimension of embedding |
||||
:type embed_dim: int |
||||
:param in_chans: number of channels of input image, defaults to 3 |
||||
:type in_chans: int, optional |
||||
:param flatten: whether to flatten output tensor, defaults to True |
||||
:type flatten: bool, optional |
||||
""" |
||||
|
||||
def __init__(self, |
||||
img_size, |
||||
patch_size, |
||||
embed_dim, |
||||
in_chans=3, |
||||
flatten=True): |
||||
super().__init__() |
||||
img_size = to_2tuple(img_size) |
||||
patch_size = to_2tuple(patch_size) |
||||
|
||||
assert_tesseract_initialization() |
||||
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env() |
||||
self.img_size = img_size |
||||
self.patch_size = patch_size |
||||
self.grid_size = (img_size[0] // patch_size[0], |
||||
img_size[1] // patch_size[1]) |
||||
self.num_patches = self.grid_size[0] * self.grid_size[1] |
||||
self.flatten = flatten |
||||
self.embed_dim = embed_dim // self.tesseract_dim # * |
||||
|
||||
self.proj = nn.Conv2d(in_chans, |
||||
self.embed_dim, |
||||
kernel_size=patch_size, |
||||
stride=patch_size, |
||||
) |
||||
|
||||
# move self to cuda before sync |
||||
self.to(get_current_device()) |
||||
|
||||
# sync |
||||
self._broadcast_conv_params() |
||||
self.proj.weight.register_hook(self._sync_grad_during_backward) |
||||
self.proj.bias.register_hook(self._sync_grad_during_backward) |
||||
|
||||
def _broadcast_conv_params(self) -> None: |
||||
xz_rank = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2P5D_XZ) |
||||
dist.broadcast(self.proj.weight, src=xz_rank[0], |
||||
group=gpc.get_group(ParallelMode.PARALLEL_2P5D_XZ)) |
||||
dist.broadcast(self.proj.bias, src=xz_rank[0], |
||||
group=gpc.get_group(ParallelMode.PARALLEL_2P5D_XZ)) |
||||
|
||||
def _sync_grad_during_backward(self, grad: Tensor) -> Tensor:
||||
dist.all_reduce(grad, group=gpc.get_group( |
||||
ParallelMode.PARALLEL_2P5D_XZ)) |
||||
grad = grad / self.tesseract_dim / self.tesseract_dep # * |
||||
return grad |
||||
|
||||
def forward(self, x: Tensor) -> Tensor: |
||||
B, C, H, W = x.shape |
||||
assert H == self.img_size[0] and W == self.img_size[1], \ |
||||
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." |
||||
x = self.proj(x) |
||||
if self.flatten: |
||||
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC |
||||
return x |
||||
|
||||
|
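# Worked example of the patch arithmetic above (hypothetical sizes):
# img_size=224, patch_size=16 -> grid_size=(14, 14) and num_patches=196;
# with embed_dim=768 and tesseract_dim=2, each rank's Conv2d projects to
# embed_dim/tesseract_dim = 384 channels, so the flattened output is
# [B, 196, 384] per rank before the token fuser adds the cls token.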
||||
@LAYERS.register_module |
||||
class ViTTokenFuser2p5D(nn.Module): |
||||
""" |
||||
Fuse cls token and pos embedding to the input |
||||
|
||||
:param img_size: image size |
||||
:type img_size: int |
||||
:param patch_size: patch size |
||||
:type patch_size: int |
||||
:param embed_dim: dimension of embedding |
||||
:type embed_dim: int |
||||
:param drop_rate: dropout probability, defaults to 0. |
||||
:type drop_rate: float, optional |
||||
""" |
||||
|
||||
def __init__(self, |
||||
img_size, |
||||
patch_size, |
||||
embed_dim, |
||||
drop_rate=0. |
||||
): |
||||
super().__init__() |
||||
img_size = to_2tuple(img_size) |
||||
patch_size = to_2tuple(patch_size) |
||||
|
||||
assert_tesseract_initialization() |
||||
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env() |
||||
self.img_size = img_size |
||||
self.patch_size = patch_size |
||||
self.grid_size = (img_size[0] // patch_size[0], |
||||
img_size[1] // patch_size[1]) |
||||
self.num_patches = self.grid_size[0] * self.grid_size[1] |
||||
self.embed_dim = embed_dim |
||||
|
||||
self.cls_token = nn.Parameter(torch.zeros( |
||||
1, 1, self.embed_dim // self.tesseract_dim)) # * |
||||
self.pos_embed = nn.Parameter(torch.zeros( |
||||
1, self.num_patches + 1, self.embed_dim // self.tesseract_dim)) # * |
||||
|
||||
# move to cuda before broadcast |
||||
self.to(get_current_device()) |
||||
|
||||
self._broadcast_params() |
||||
self.cls_token.register_hook(self._sync_grad_hook) |
||||
self.pos_embed.register_hook(self._sync_grad_hook) |
||||
self.pos_drop = nn.Dropout(p=drop_rate) |
||||
self._set_tensor_parallel_attribute() |
||||
|
||||
def _set_tensor_parallel_attribute(self): |
||||
set_tensor_parallel_attribute(self.cls_token) |
||||
set_tensor_parallel_attribute(self.pos_embed) |
||||
|
||||
def _broadcast_params(self) -> None: |
||||
" broadcast to all column ranks for data consistency " |
||||
xz_rank = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2P5D_XZ) |
||||
dist.broadcast(self.cls_token, src=xz_rank[0], |
||||
group=gpc.get_group(ParallelMode.PARALLEL_2P5D_XZ)) |
||||
dist.broadcast(self.pos_embed, src=xz_rank[0], |
||||
group=gpc.get_group(ParallelMode.PARALLEL_2P5D_XZ)) |
||||
|
||||
def _sync_grad_hook(self, grad: Tensor) -> Tensor:
||||
dist.all_reduce(grad, group=gpc.get_group( |
||||
ParallelMode.PARALLEL_2P5D_XZ)) |
||||
grad = grad / self.tesseract_dim / self.tesseract_dep # * |
||||
return grad |
||||
|
||||
def forward(self, x: Tensor) -> Tensor: |
||||
# stole cls_tokens impl from Phil Wang, thanks |
||||
cls_token = self.cls_token.expand(x.shape[0], -1, -1) |
||||
x = torch.cat((cls_token, x), dim=1) |
||||
x = self.pos_drop(x + self.pos_embed) |
||||
return x |
||||
|
||||
|
||||
@LAYERS.register_module |
||||
class ViTInputSplitter2p5D(nn.Module): |
||||
|
||||
def __init__(self): |
||||
super().__init__() |
||||
assert_tesseract_initialization() |
||||
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env() |
||||
|
||||
def forward(self, x: Tensor) -> Tensor: |
||||
batch_size = x.size(0) |
||||
return _ViT_Split_2p5D.apply( |
||||
x, |
||||
batch_size, |
||||
self.tesseract_dim, |
||||
self.tesseract_dep, |
||||
ParallelMode.PARALLEL_2P5D_XZ, |
||||
) |
@ -0,0 +1,266 @@
|
||||
import math |
||||
|
||||
import torch |
||||
from torch import Tensor |
||||
from torch.nn import Parameter, init as init |
||||
|
||||
from colossalai.context import seed, ParallelMode |
||||
from colossalai.core import global_context as gpc |
||||
from colossalai.registry import LAYERS |
||||
from colossalai.utils import get_current_device |
||||
from ._operation import Matmul_AB_2p5D, Add_Bias_2p5D, _LayerNorm_2p5D |
||||
from ._utils import get_tesseract_dim_dep_from_env, assert_tesseract_initialization |
||||
from .._common_utils import divide, set_tensor_parallel_attribute |
||||
from ..base_layer import ParallelLayer |
||||
|
||||
|
||||
@LAYERS.register_module |
||||
class Linear2p5D(ParallelLayer): |
||||
"""Linear layer for 2.5D parallelism |
||||
|
||||
:param in_features: size of each input sample |
||||
:type in_features: int |
||||
:param out_features: size of each output sample |
||||
:type out_features: int |
||||
:param bias: If set to ``False``, the layer will not learn an additive bias, defaults to True |
||||
:type bias: bool, optional |
||||
:param dtype: The dtype of parameters, defaults to None |
||||
:type dtype: torch.dtype, optional |
||||
""" |
||||
|
||||
def __init__(self, |
||||
in_features: int, |
||||
out_features: int, |
||||
bias: bool = True, |
||||
dtype=None, |
||||
skip_bias_add: bool = False |
||||
): |
||||
super().__init__() |
||||
|
||||
self.in_features = in_features |
||||
self.out_features = out_features |
||||
self.skip_bias_add = skip_bias_add |
||||
|
||||
# parallel setting |
||||
assert_tesseract_initialization() |
||||
self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) |
||||
self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) |
||||
self.dep_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) |
||||
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env() |
||||
|
||||
# partitioning dimension |
||||
self.input_size_per_partition = divide(in_features, self.tesseract_dim) |
||||
self.hidden_size_per_partition = divide( |
||||
out_features, self.tesseract_dim) |
||||
|
||||
# create weight, shape: [k/q, h/q] |
||||
factory_kwargs = {'device': get_current_device(), 'dtype': dtype} |
||||
self.weight = Parameter(torch.empty( |
||||
self.input_size_per_partition, |
||||
self.hidden_size_per_partition, |
||||
**factory_kwargs)) |
||||
|
||||
# create bias, shape: [h/q] |
||||
if bias: |
||||
self.bias = Parameter(torch.empty( |
||||
self.hidden_size_per_partition, |
||||
**factory_kwargs)) |
||||
else: |
||||
self.register_parameter('bias', None) |
||||
|
||||
# initialize parameters |
||||
self.reset_parameters() |
||||
self._set_tensor_parallel_attributes() |
||||
|
||||
def _set_tensor_parallel_attributes(self): |
||||
set_tensor_parallel_attribute(self.weight) |
||||
if self.bias is not None: |
||||
set_tensor_parallel_attribute(self.bias) |
||||
|
||||
def reset_parameters(self) -> None: |
||||
# setting |
||||
fan_in = self.in_features |
||||
a = math.sqrt(5) |
||||
nonlinearity = 'leaky_relu' |
||||
|
||||
# init weight |
||||
std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in) |
||||
bound = math.sqrt(3.0) * std |
||||
with seed(ParallelMode.TENSOR): |
||||
init.uniform_(self.weight, -bound, bound) |
||||
|
||||
# init bias |
||||
if self.bias is not None: |
||||
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 |
||||
with seed(ParallelMode.TENSOR): |
||||
init.uniform_(self.bias, -bound, bound) |
||||
|
||||
def forward(self, x: Tensor) -> Tensor: |
||||
# input: [m/dq, n/q, k/q] |
||||
# output: [m/dq, n/q, h/q] |
||||
out_shape = x.shape[:-1] + (self.hidden_size_per_partition,) |
||||
output = Matmul_AB_2p5D.apply( |
||||
x, |
||||
self.weight, |
||||
self.tesseract_dim, |
||||
self.tesseract_dep, |
||||
out_shape, |
||||
self.row_rank, self.col_rank, self.dep_rank, |
||||
ParallelMode.PARALLEL_2P5D_ROW, |
||||
ParallelMode.PARALLEL_2P5D_COL, |
||||
ParallelMode.PARALLEL_2P5D_DEP, |
||||
self.data_parallel_rank, |
||||
self.pipeline_parallel_rank, |
||||
self.pipeline_parallel_size, |
||||
self.tensor_parallel_size, |
||||
) |
||||
|
||||
if self.bias is not None: |
||||
if self.skip_bias_add: |
||||
bias = Add_Bias_2p5D.apply( |
||||
None, |
||||
self.bias, |
||||
self.hidden_size_per_partition, |
||||
self.tesseract_dim, self.tesseract_dep, |
||||
self.row_rank, self.col_rank, self.dep_rank, |
||||
ParallelMode.PARALLEL_2P5D_ROW, |
||||
ParallelMode.PARALLEL_2P5D_COL, |
||||
ParallelMode.PARALLEL_2P5D_DEP, |
||||
True, |
||||
self.data_parallel_rank, |
||||
self.pipeline_parallel_rank, |
||||
self.pipeline_parallel_size, |
||||
self.tensor_parallel_size |
||||
) |
||||
return output, bias |
||||
else: |
||||
output = Add_Bias_2p5D.apply( |
||||
output, |
||||
self.bias, |
||||
self.hidden_size_per_partition, |
||||
self.tesseract_dim, self.tesseract_dep, |
||||
self.row_rank, self.col_rank, self.dep_rank, |
||||
ParallelMode.PARALLEL_2P5D_ROW, |
||||
ParallelMode.PARALLEL_2P5D_COL, |
||||
ParallelMode.PARALLEL_2P5D_DEP, |
||||
False, |
||||
self.data_parallel_rank, |
||||
self.pipeline_parallel_rank, |
||||
self.pipeline_parallel_size, |
||||
self.tensor_parallel_size |
||||
) |
||||
return output |
||||
else: |
||||
return output |
||||
|
||||
|
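# Worked example of the 2.5D partitioning below (hypothetical sizes):
# with tesseract_dim q=2, tesseract_dep d=2, in_features k=1024 and
# out_features h=4096, each rank stores a weight shard of shape
# [k/q, h/q] = [512, 2048] and a bias shard of shape [h/q] = [2048];
# an input of logical shape [m, n, k] arrives as [m/(d*q), n/q, k/q] and
# Matmul_AB_2p5D returns the matching [m/(d*q), n/q, h/q] output shard.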
||||
@LAYERS.register_module |
||||
class LayerNorm2p5D(ParallelLayer): |
||||
r"""Layer Normalization for 2.5D parallelism |
||||
|
||||
:param normalized_shape: input shape from an expected input of size
:math:`[* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1] \times \ldots \times \text{normalized\_shape}[-1]]`.
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
||||
:type normalized_shape: int |
||||
:param eps: a value added to the denominator for numerical stability, defaults to 1e-05 |
||||
:type eps: float, optional |
||||
:param dtype: The dtype of parameters, defaults to None |
||||
:type dtype: torch.dtype, optional |
||||
""" |
||||
def __init__(self, |
||||
normalized_shape: int, |
||||
eps: float = 1e-05, |
||||
dtype=None |
||||
): |
||||
super().__init__() |
||||
|
||||
# layer norm config |
||||
self.normalized_shape = normalized_shape |
||||
self.variance_epsilon = eps |
||||
|
||||
# parallel setting |
||||
assert_tesseract_initialization() |
||||
self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) |
||||
self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) |
||||
self.dep_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) |
||||
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env() |
||||
|
||||
# partitioning dimension |
||||
self.partitioned_partition = divide( |
||||
normalized_shape, self.tesseract_dim) # * |
||||
|
||||
# create parameters |
||||
factory_kwargs = {'device': get_current_device(), 'dtype': dtype} |
||||
|
||||
if self.row_rank == 0: |
||||
self.gamma = Parameter(torch.ones( |
||||
self.partitioned_partition, |
||||
**factory_kwargs)) |
||||
self.beta = Parameter(torch.zeros( |
||||
self.partitioned_partition, |
||||
**factory_kwargs)) |
||||
else: |
||||
self.gamma = Parameter(torch.tensor( |
||||
1.0, |
||||
requires_grad=True, |
||||
**factory_kwargs)) |
||||
self.beta = Parameter(torch.tensor( |
||||
1.0, |
||||
requires_grad=True, |
||||
**factory_kwargs)) |
||||
self._set_tensor_parallel_attribute() |
||||
|
||||
def _set_tensor_parallel_attribute(self): |
||||
set_tensor_parallel_attribute(self.gamma) |
||||
set_tensor_parallel_attribute(self.beta) |
||||
|
||||
def forward(self, x: Tensor) -> Tensor: |
||||
with torch.no_grad(): |
||||
E_x = torch.sum(x, dim=-1, keepdim=True) # [b/q, s, 1] |
||||
torch.distributed.all_reduce( |
||||
E_x, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW)) |
||||
E_x /= self.normalized_shape |
||||
|
||||
# Var_x in the block below is the sum of input^2 |
||||
Var_x = torch.sum(x * x, dim=-1, keepdim=True) # [b/q, s, 1] |
||||
torch.distributed.all_reduce( |
||||
Var_x, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW)) |
||||
Var_x /= self.normalized_shape |
||||
|
||||
Var_x = Var_x - E_x * E_x # variance of x [b/q, s, 1] |
||||
# this time 1/sqrt(Var_x + epsilon) |
||||
Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon) |
||||
|
||||
output = _LayerNorm_2p5D.apply(x, E_x, Var_x, self.normalized_shape, |
||||
ParallelMode.PARALLEL_2P5D_ROW, |
||||
ParallelMode.PARALLEL_2P5D_COL, |
||||
ParallelMode.PARALLEL_2P5D_DEP) |
||||
bias = Add_Bias_2p5D.apply( |
||||
None, self.beta, self.partitioned_partition, |
||||
self.tesseract_dim, self.tesseract_dep, |
||||
self.row_rank, self.col_rank, self.dep_rank, |
||||
ParallelMode.PARALLEL_2P5D_ROW, |
||||
ParallelMode.PARALLEL_2P5D_COL, |
||||
ParallelMode.PARALLEL_2P5D_DEP, |
||||
True, |
||||
self.data_parallel_rank, |
||||
self.pipeline_parallel_rank, |
||||
self.pipeline_parallel_size, |
||||
self.tensor_parallel_size |
||||
) |
||||
scale = Add_Bias_2p5D.apply( |
||||
None, self.gamma, self.partitioned_partition, |
||||
self.tesseract_dim, self.tesseract_dep, |
||||
self.row_rank, self.col_rank, self.dep_rank, |
||||
ParallelMode.PARALLEL_2P5D_ROW, |
||||
ParallelMode.PARALLEL_2P5D_COL, |
||||
ParallelMode.PARALLEL_2P5D_DEP, |
||||
True, |
||||
self.data_parallel_rank, |
||||
self.pipeline_parallel_rank, |
||||
self.pipeline_parallel_size, |
||||
self.tensor_parallel_size |
||||
) |
||||
output = torch.addcmul(bias, scale, output) |
||||
return output |
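# Note on the statistics above (a sketch, not in the original file): the
# row-group all-reduce yields E[x] and E[x^2] over the full hidden dimension,
# and the variance is recovered as Var[x] = E[x^2] - E[x]^2. For example, for
# x = [1, 2, 3]: E[x] = 2, E[x^2] = 14/3, so Var[x] = 14/3 - 4 = 2/3.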
@ -0,0 +1,9 @@
|
||||
from ._operation import Matmul_ABT_3D, Matmul_ATB_3D, Matmul_AB_3D, Mul_3D, Sum_3D, Add_3D, Reduce_3D |
||||
from ._vit import ViTHead3D, ViTMLP3D, ViTPatchEmbedding3D, ViTSelfAttention3D |
||||
from .layers import Linear3D, LayerNorm3D |
||||
|
||||
__all__ = [ |
||||
'Matmul_ABT_3D', 'Matmul_ATB_3D', 'Matmul_AB_3D', 'Mul_3D', 'Sum_3D', 'Add_3D', 'Reduce_3D', |
||||
'ViTHead3D', 'ViTMLP3D', 'ViTPatchEmbedding3D', 'ViTSelfAttention3D', |
||||
'Linear3D', 'LayerNorm3D' |
||||
] |
@ -0,0 +1,349 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
from typing import Any, Tuple |
||||
|
||||
import torch |
||||
import torch.distributed as dist |
||||
from colossalai.communication import all_gather, reduce_scatter, scatter |
||||
from colossalai.context.parallel_mode import ParallelMode |
||||
from colossalai.core import global_context as gpc |
||||
from colossalai.utils import empty_cache, get_current_device |
||||
from torch import Tensor |
||||
|
||||
|
||||
class Matmul_AB_3D(torch.autograd.Function): |
||||
"""Matrix multiplication for :math:`C = AB` |
||||
""" |
||||
@staticmethod |
||||
def forward(ctx: Any, |
||||
A: Tensor, |
||||
B: Tensor, |
||||
depth: int, |
||||
input_parallel_mode: ParallelMode, |
||||
weight_parallel_mode: ParallelMode, |
||||
output_parallel_mode: ParallelMode, |
||||
input_dim: int = 0, |
||||
weight_dim: int = -1, |
||||
output_dim: int = 0) -> Tensor: |
||||
# A: [m/q^2, n, k/q] |
||||
# B: [k/q, h/q^2] |
||||
# C: [m/q^2, n, h/q] |
||||
empty_cache() |
||||
ctx.save_for_backward(A, B) |
||||
|
||||
assert A.shape[-1] == B.shape[0], \ |
||||
'Invalid shapes: A={}, B={}.'.format(A.shape, B.shape) |
||||
|
||||
A_temp = all_gather(A, input_dim, input_parallel_mode) |
||||
B_temp = all_gather(B, weight_dim, weight_parallel_mode) |
||||
|
||||
C = torch.matmul(A_temp, B_temp) |
||||
out = reduce_scatter(C, output_dim, output_parallel_mode) |
||||
|
||||
ctx.depth = depth |
||||
ctx.A_group_parallel_mode = input_parallel_mode |
||||
ctx.B_group_parallel_mode = weight_parallel_mode |
||||
ctx.C_group_parallel_mode = output_parallel_mode |
||||
ctx.A_dim = input_dim |
||||
ctx.B_dim = weight_dim |
||||
ctx.C_dim = output_dim |
||||
|
||||
return out |
||||
|
||||
@staticmethod |
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: |
||||
A, B = ctx.saved_tensors |
||||
with torch.no_grad(): |
||||
A_grad = Matmul_ABT_3D.apply(output_grad, B, ctx.depth, |
||||
ctx.C_group_parallel_mode, |
||||
ctx.B_group_parallel_mode, |
||||
ctx.A_group_parallel_mode, ctx.C_dim, |
||||
ctx.B_dim, ctx.A_dim) |
||||
B_grad = Matmul_ATB_3D.apply(A, output_grad, ctx.depth, |
||||
ctx.A_group_parallel_mode, |
||||
ctx.C_group_parallel_mode, |
||||
ctx.B_group_parallel_mode, ctx.A_dim, |
||||
ctx.C_dim, ctx.B_dim) |
||||
return A_grad, B_grad, None, None, None, None, None, None, None |
||||
|
||||
|
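# Worked shape example for Matmul_AB_3D (hypothetical sizes, cube depth q=2):
# each rank holds A of shape [m/q^2, n, k/q] and B of shape [k/q, h/q^2];
# A is all-gathered over the input group (dim 0 -> [m/q, n, k/q]) and B over
# the weight group (dim -1 -> [k/q, h/q]), the local matmul yields
# [m/q, n, h/q], and the reduce-scatter over the output group (dim 0) sums
# the partial products and returns the C shard of shape [m/q^2, n, h/q].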
||||
class Matmul_ABT_3D(torch.autograd.Function): |
||||
"""Matrix multiplication for :math:`C = AB^T` |
||||
""" |
||||
@staticmethod |
||||
def forward(ctx: Any, |
||||
A: Tensor, |
||||
B: Tensor, |
||||
depth: int, |
||||
input_parallel_mode: ParallelMode, |
||||
weight_parallel_mode: ParallelMode, |
||||
output_parallel_mode: ParallelMode, |
||||
input_dim: int = 0, |
||||
weight_dim: int = -1, |
||||
output_dim: int = 0) -> Tensor: |
||||
# A: [m/q^2, n, h/q] |
||||
# B: [k/q, h/q^2] |
||||
# C: [m/q^2, n, k/q] |
||||
empty_cache() |
||||
ctx.save_for_backward(A, B) |
||||
|
||||
A_temp = all_gather(A, input_dim, input_parallel_mode) |
||||
B_temp = all_gather(B, weight_dim, weight_parallel_mode) |
||||
|
||||
C = torch.matmul(A_temp, B_temp.transpose(0, 1)) |
||||
out = reduce_scatter(C, output_dim, output_parallel_mode) |
||||
|
||||
ctx.depth = depth |
||||
ctx.A_group_parallel_mode = input_parallel_mode |
||||
ctx.B_group_parallel_mode = weight_parallel_mode |
||||
ctx.C_group_parallel_mode = output_parallel_mode |
||||
ctx.A_dim = input_dim |
||||
ctx.B_dim = weight_dim |
||||
ctx.C_dim = output_dim |
||||
|
||||
return out |
||||
|
||||
@staticmethod |
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: |
||||
A, B = ctx.saved_tensors |
||||
with torch.no_grad(): |
||||
A_grad = Matmul_AB_3D.apply(output_grad, B, ctx.depth, |
||||
ctx.C_group_parallel_mode, |
||||
ctx.B_group_parallel_mode, |
||||
ctx.A_group_parallel_mode, ctx.C_dim, |
||||
ctx.B_dim, ctx.A_dim) |
||||
B_grad = Matmul_ATB_3D.apply(output_grad, A, ctx.depth, |
||||
ctx.C_group_parallel_mode, |
||||
ctx.A_group_parallel_mode, |
||||
ctx.B_group_parallel_mode, ctx.C_dim, |
||||
ctx.A_dim, ctx.B_dim) |
||||
return A_grad, B_grad, None, None, None, None, None, None, None |
||||
|
||||
|
||||
class Matmul_ATB_3D(torch.autograd.Function): |
||||
"""Matrix multiplication for :math:`C = A^TB` |
||||
""" |
||||
@staticmethod |
||||
def forward(ctx: Any, |
||||
A: Tensor, |
||||
B: Tensor, |
||||
depth: int, |
||||
input_parallel_mode: ParallelMode, |
||||
weight_parallel_mode: ParallelMode, |
||||
output_parallel_mode: ParallelMode, |
||||
input_dim: int = 0, |
||||
weight_dim: int = 0, |
||||
output_dim: int = -1) -> Tensor: |
||||
# A: [m/q^2, n, k/q] |
||||
# B: [m/q^2, n, h/q] |
||||
# C: [k/q, h/q^2] |
||||
empty_cache() |
||||
ctx.save_for_backward(A, B) |
||||
|
||||
A_temp = all_gather(A, input_dim, input_parallel_mode) |
||||
A_temp = A_temp.reshape(-1, A.shape[-1]) |
||||
B_temp = all_gather(B, weight_dim, weight_parallel_mode) |
||||
B_temp = B_temp.reshape(-1, B.shape[-1]) |
||||
|
||||
C = torch.matmul(A_temp.transpose(0, 1), B_temp) |
||||
out = reduce_scatter(C, output_dim, output_parallel_mode) |
||||
|
||||
ctx.depth = depth |
||||
ctx.A_group_parallel_mode = input_parallel_mode |
||||
ctx.B_group_parallel_mode = weight_parallel_mode |
||||
ctx.C_group_parallel_mode = output_parallel_mode |
||||
ctx.A_dim = input_dim |
||||
ctx.B_dim = weight_dim |
||||
ctx.C_dim = output_dim |
||||
|
||||
return out |
||||
|
||||
@staticmethod |
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: |
||||
A, B = ctx.saved_tensors |
||||
with torch.no_grad(): |
||||
A_grad = Matmul_ABT_3D.apply(B, output_grad, ctx.depth, |
||||
ctx.B_group_parallel_mode, |
||||
ctx.C_group_parallel_mode, |
||||
ctx.A_group_parallel_mode, ctx.B_dim, |
||||
ctx.C_dim, ctx.A_dim) |
||||
B_grad = Matmul_AB_3D.apply(A, output_grad, ctx.depth, |
||||
ctx.A_group_parallel_mode, |
||||
ctx.C_group_parallel_mode, |
||||
ctx.B_group_parallel_mode, ctx.A_dim, |
||||
ctx.C_dim, ctx.B_dim) |
||||
return A_grad, B_grad, None, None, None, None, None, None, None |
||||
|
||||
|
||||
class Add_3D(torch.autograd.Function): |
||||
"""Matrix add bias: :math:`C = A + b` |
||||
""" |
||||
@staticmethod |
||||
def forward(ctx: Any, input_: Tensor, bias: Tensor, depth: int, |
||||
input_parallel_mode: ParallelMode, |
||||
weight_parallel_mode: ParallelMode, |
||||
output_parallel_mode: ParallelMode) -> Tensor: |
||||
# input: [m/q^2, n, h/q] |
||||
# bias: [h/q^2] |
||||
ranks_in_group = gpc.get_ranks_in_group(input_parallel_mode) |
||||
src_rank = ranks_in_group[gpc.get_local_rank(output_parallel_mode)] |
||||
bias_temp = bias.clone() |
||||
dist.broadcast(bias_temp, |
||||
src=src_rank, |
||||
group=gpc.get_group(input_parallel_mode)) |
||||
# [h/q] |
||||
bias_temp = all_gather(bias_temp, -1, weight_parallel_mode) |
||||
|
||||
out = input_ + bias_temp |
||||
|
||||
ctx.depth = depth |
||||
ctx.src_rank = src_rank |
||||
ctx.A_group_parallel_mode = input_parallel_mode |
||||
ctx.B_group_parallel_mode = weight_parallel_mode |
||||
ctx.C_group_parallel_mode = output_parallel_mode |
||||
|
||||
return out |
||||
|
||||
@staticmethod |
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: |
||||
# output_grad: [m/q^2, n, h/q] |
||||
with torch.no_grad(): |
||||
# [h/q] |
||||
grad = torch.sum(output_grad, |
||||
dim=tuple(range(len(output_grad.shape))[:-1])) |
||||
bias_grad = reduce_scatter(grad, -1, ctx.B_group_parallel_mode) |
||||
dist.reduce(bias_grad, |
||||
dst=ctx.src_rank, |
||||
group=gpc.get_group(ctx.A_group_parallel_mode)) |
||||
if gpc.get_local_rank( |
||||
ctx.A_group_parallel_mode) != gpc.get_local_rank( |
||||
ctx.C_group_parallel_mode): |
||||
bias_grad = None |
||||
return output_grad, bias_grad, None, None, None, None |
||||
|
||||
|
||||
class Mul_3D(torch.autograd.Function): |
||||
"""Matrix multiplication for :math:`C = A * b` |
||||
""" |
||||
@staticmethod |
||||
def forward(ctx: Any, input_: Tensor, bias: Tensor, depth: int, |
||||
input_parallel_mode: ParallelMode, |
||||
weight_parallel_mode: ParallelMode, |
||||
output_parallel_mode: ParallelMode) -> Tensor: |
||||
# input: [m/q^2, n, h/q] |
||||
# bias: [h/q^2] |
||||
ranks_in_group = gpc.get_ranks_in_group(input_parallel_mode) |
||||
src_rank = ranks_in_group[gpc.get_local_rank(output_parallel_mode)] |
||||
# [h/q^2] |
||||
bias_temp = bias.clone() |
||||
dist.broadcast(bias_temp, |
||||
src=src_rank, |
||||
group=gpc.get_group(input_parallel_mode)) |
||||
# [h/q] |
||||
bias_temp = all_gather(bias_temp, -1, weight_parallel_mode) |
||||
|
||||
empty_cache() |
||||
ctx.save_for_backward(input_, bias_temp) |
||||
|
||||
out = torch.mul(input_, bias_temp) |
||||
|
||||
ctx.depth = depth |
||||
ctx.src_rank = src_rank |
||||
ctx.A_group_parallel_mode = input_parallel_mode |
||||
ctx.B_group_parallel_mode = weight_parallel_mode |
||||
ctx.C_group_parallel_mode = output_parallel_mode |
||||
|
||||
return out |
||||
|
||||
@staticmethod |
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: |
||||
# output_grad: [m/q^2, n, h/q] |
||||
with torch.no_grad(): |
||||
input_, bias = ctx.saved_tensors |
||||
# [m/q^2, n, h/q] |
||||
input_grad = torch.mul(output_grad, bias) |
||||
# [h/q] |
||||
grad = torch.mul(output_grad, input_) |
||||
grad = torch.sum(grad, |
||||
dim=tuple(range(len(output_grad.shape))[:-1])) |
||||
bias_grad = reduce_scatter(grad, -1, ctx.B_group_parallel_mode) |
||||
dist.reduce(bias_grad, |
||||
dst=ctx.src_rank, |
||||
group=gpc.get_group(ctx.A_group_parallel_mode)) |
||||
if gpc.get_local_rank( |
||||
ctx.A_group_parallel_mode) != gpc.get_local_rank( |
||||
ctx.C_group_parallel_mode): |
||||
bias_grad = None |
||||
return input_grad, bias_grad, None, None, None, None |
||||
|
||||
|
||||
class Sum_3D(torch.autograd.Function): |
||||
"""Compute the sum of input tensors |
||||
""" |
||||
@staticmethod |
||||
def forward(ctx: Any, |
||||
input_: Tensor, |
||||
dim: int, |
||||
depth: int, |
||||
parallel_mode: ParallelMode, |
||||
keepdim: bool = False) -> Tensor: |
||||
# input: [m/q^2, n, h/q] |
||||
out = torch.sum(input_, dim=dim, keepdim=keepdim) |
||||
dist.all_reduce(out, group=gpc.get_group(parallel_mode)) |
||||
|
||||
ctx.input_shape = input_.shape |
||||
ctx.depth = depth |
||||
ctx.group = parallel_mode |
||||
ctx.dim = dim |
||||
return out |
||||
|
||||
@staticmethod |
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: |
||||
with torch.no_grad(): |
||||
output_grad = output_grad.contiguous() |
||||
dist.all_reduce(output_grad, group=gpc.get_group(ctx.group)) |
||||
if len(output_grad.shape) < len(ctx.input_shape): |
||||
output_grad = torch.unsqueeze(output_grad, ctx.dim) |
||||
dims = [1 for _ in range(len(output_grad.shape))] |
||||
dims[ctx.dim] = ctx.input_shape[ctx.dim] |
||||
input_grad = output_grad.repeat(tuple(dims)) |
||||
return input_grad, None, None, None, None, None |
||||
|
||||
|
||||
class Reduce_3D(torch.autograd.Function): |
||||
"""Reduce input tensors |
||||
""" |
||||
@staticmethod |
||||
def forward(ctx: Any, input_: Tensor, depth: int, |
||||
parallel_mode: ParallelMode) -> Tensor: |
||||
dist.all_reduce(input_, group=gpc.get_group(parallel_mode)) |
||||
return input_.clone() |
||||
|
||||
@staticmethod |
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: |
||||
return output_grad, None, None |
||||
|
||||
|
||||
class Slice_3D(torch.autograd.Function): |
||||
"""Slice input tensor |
||||
""" |
||||
@staticmethod |
||||
def forward(ctx: Any, input_: Tensor, dim: int, depth: int, |
||||
parallel_mode: ParallelMode) -> Tensor: |
||||
rank = gpc.get_local_rank(parallel_mode) |
||||
out = torch.chunk(input_, depth, dim=dim)[rank].contiguous() |
||||
|
||||
ctx.depth = depth |
||||
ctx.parallel_mode = parallel_mode |
||||
ctx.dim = dim |
||||
ctx.input_shape = input_.shape |
||||
|
||||
return out |
||||
|
||||
@staticmethod |
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: |
||||
with torch.no_grad(): |
||||
input_grad = all_gather(output_grad, ctx.dim, ctx.parallel_mode) |
||||
input_grad = input_grad.reshape(ctx.input_shape)
||||
return input_grad, None, None, None |
@ -0,0 +1,49 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
import os |
||||
|
||||
from colossalai.constants import DEPTH_3D |
||||
from colossalai.context.parallel_mode import ParallelMode |
||||
from colossalai.core import global_context as gpc |
||||
from torch import Tensor |
||||
|
||||
|
||||
def get_depth_from_env() -> int: |
||||
try: |
||||
depth = os.environ[DEPTH_3D] |
||||
depth = int(depth) |
||||
assert depth > 0, 'DEPTH must be greater than zero' |
||||
return depth |
||||
|
||||
except KeyError as e: |
||||
raise EnvironmentError( |
||||
'DEPTH is not found in the current environment, ' |
||||
'please make sure that you have used the correct process group initializer' |
||||
) |
||||
|
||||
|
||||
def get_last_group(a, b): |
||||
mapping = { |
||||
ParallelMode.PARALLEL_3D_INPUT: 'A', |
||||
ParallelMode.PARALLEL_3D_WEIGHT: 'B', |
||||
ParallelMode.PARALLEL_3D_OUTPUT: 'C', |
||||
} |
||||
|
||||
res = chr( |
||||
ord('A') + ord('B') + ord('C') - ord(mapping[a]) - ord(mapping[b])) |
||||
|
||||
if res == 'A': |
||||
return ParallelMode.PARALLEL_3D_INPUT |
||||
elif res == 'B': |
||||
return ParallelMode.PARALLEL_3D_WEIGHT |
||||
elif res == 'C': |
||||
return ParallelMode.PARALLEL_3D_OUTPUT |
||||
|
||||
|
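# Example of the character arithmetic above: the three 3D groups are mapped to
# 'A', 'B' and 'C', and the remaining letter identifies the third group, e.g.
#   get_last_group(ParallelMode.PARALLEL_3D_INPUT,
#                  ParallelMode.PARALLEL_3D_WEIGHT)
# returns ParallelMode.PARALLEL_3D_OUTPUT ('A' and 'B' leave 'C').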
||||
def dbg_check_shape(tensor: Tensor, shape: tuple): |
||||
rank = gpc.get_global_rank() |
||||
if rank == 0: |
||||
print(tensor.shape) |
||||
assert tensor.shape == shape, \ |
||||
'{} does not match {}'.format(tensor.shape, shape) |
@ -0,0 +1,368 @@
|
||||
import math |
||||
from typing import Tuple |
||||
|
||||
import torch |
||||
import torch.distributed as dist |
||||
from colossalai.context import ParallelMode, seed |
||||
from colossalai.core import global_context as gpc |
||||
from colossalai.registry import LAYERS |
||||
from colossalai.utils import checkpoint, get_current_device |
||||
from torch import Tensor, dtype, nn |
||||
|
||||
from .._common_utils import ACT2FN, divide, set_tensor_parallel_attribute |
||||
from ..vanilla_vision_transformer.layers import to_2tuple |
||||
from ._utils import get_depth_from_env |
||||
from .layers import Linear3D |
||||
|
||||
|
||||
@LAYERS.register_module |
||||
class ViTPatchEmbedding3D(nn.Module): |
||||
""" 3D Image to Patch Embedding |
||||
|
||||
:param img_size: image size
||||
:type img_size: int |
||||
:param patch_size: patch size |
||||
:type patch_size: int |
||||
:param in_chans: number of channels of input image |
||||
:type in_chans: int |
||||
:param embed_size: dimension of embedding |
||||
:type embed_size: int |
||||
:param drop_prob: dropout probability |
||||
:type drop_prob: float |
||||
:param flatten: whether to flatten output tensor, defaults to True |
||||
:type flatten: bool, optional |
||||
""" |
||||
def __init__(self, |
||||
img_size: int, |
||||
patch_size: int, |
||||
in_chans: int, |
||||
embed_size: int, |
||||
drop_prob: float, |
||||
flatten: bool = True): |
||||
super().__init__() |
||||
self.depth = get_depth_from_env() |
||||
self.input_parallel_mode = ParallelMode.PARALLEL_3D_INPUT |
||||
self.weight_parallel_mode = ParallelMode.PARALLEL_3D_WEIGHT |
||||
self.output_parallel_mode = ParallelMode.PARALLEL_3D_OUTPUT |
||||
img_size = to_2tuple(img_size) |
||||
patch_size = to_2tuple(patch_size) |
||||
self.img_size = img_size |
||||
self.patch_size = patch_size |
||||
self.grid_size = (img_size[0] // patch_size[0], |
||||
img_size[1] // patch_size[1]) |
||||
self.embed_size = embed_size |
||||
self.embed_size_per_partition = divide(self.embed_size, self.depth) |
||||
self.num_patches = self.grid_size[0] * self.grid_size[1] |
||||
self.flatten = flatten |
||||
|
||||
with seed(ParallelMode.TENSOR): |
||||
self.proj = nn.Conv2d(in_chans, |
||||
self.embed_size_per_partition, |
||||
kernel_size=patch_size, |
||||
stride=patch_size) |
||||
|
||||
self.cls_token = nn.Parameter( |
||||
torch.zeros(1, 1, self.embed_size_per_partition)) |
||||
self.pos_embed = nn.Parameter( |
||||
torch.zeros(1, self.num_patches + 1, |
||||
self.embed_size_per_partition)) |
||||
self.pos_drop = nn.Dropout(drop_prob) |
||||
|
||||
self._sync_parameters() |
||||
self.proj.weight.register_hook(self._sync_grad_hook) |
||||
self.proj.bias.register_hook(self._sync_grad_hook) |
||||
self.cls_token.register_hook(self._sync_grad_hook) |
||||
self.pos_embed.register_hook(self._sync_grad_hook) |
||||
self._set_tensor_parallel_attribute() |
||||
|
||||
def _set_tensor_parallel_attribute(self): |
||||
set_tensor_parallel_attribute(self.proj.weight) |
||||
set_tensor_parallel_attribute(self.proj.bias) |
||||
set_tensor_parallel_attribute(self.cls_token) |
||||
set_tensor_parallel_attribute(self.pos_embed) |
||||
|
||||
def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]: |
||||
return self.input_parallel_mode, self.weight_parallel_mode |
||||
|
||||
def _sync_parameters(self): |
||||
self.to(get_current_device()) |
||||
weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0] |
||||
dist.broadcast(self.proj.weight, |
||||
src=weight_src_rank, |
||||
group=gpc.get_group(self.weight_parallel_mode)) |
||||
dist.broadcast(self.proj.bias, |
||||
src=weight_src_rank, |
||||
group=gpc.get_group(self.weight_parallel_mode)) |
||||
input_src_rank = gpc.get_ranks_in_group(self.input_parallel_mode)[0] |
||||
dist.broadcast(self.proj.weight, |
||||
src=input_src_rank, |
||||
group=gpc.get_group(self.input_parallel_mode)) |
||||
dist.broadcast(self.proj.bias, |
||||
src=input_src_rank, |
||||
group=gpc.get_group(self.input_parallel_mode)) |
||||
set_tensor_parallel_attribute(self.proj.weight) |
||||
set_tensor_parallel_attribute(self.proj.bias) |
||||
set_tensor_parallel_attribute(self.cls_token) |
||||
set_tensor_parallel_attribute(self.pos_embed) |
||||
|
||||
def _sync_grad_hook(self, grad: Tensor) -> Tensor:
||||
dist.all_reduce(grad, group=gpc.get_group(self.input_parallel_mode)) |
||||
dist.all_reduce(grad, group=gpc.get_group(self.weight_parallel_mode)) |
||||
return grad |
||||
|
||||
def forward(self, x: Tensor) -> Tensor: |
||||
B, C, H, W = x.shape |
||||
assert H == self.img_size[0] and W == self.img_size[1], \ |
||||
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." |
||||
x = self.proj(x) |
||||
if self.flatten: |
||||
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC |
||||
|
||||
# split a partition from embedded states |
||||
x = torch.chunk(x, self.depth, dim=0)[gpc.get_local_rank( |
||||
self.weight_parallel_mode)].contiguous() |
||||
x = torch.chunk(x, self.depth, dim=0)[gpc.get_local_rank( |
||||
self.input_parallel_mode)].contiguous() |
||||
|
||||
# add cls token & pos embedding |
||||
# [b/q^2,s,h/q] --> [b/q^2, 1+s, h/q] |
||||
cls_token = self.cls_token.expand(x.shape[0], -1, -1) |
||||
x = torch.cat((cls_token, x), dim=1) |
||||
|
||||
with seed(ParallelMode.TENSOR): |
||||
x = self.pos_drop(x + self.pos_embed) |
||||
|
||||
return x |
||||
|
||||
|
||||
@LAYERS.register_module |
||||
class ViTSelfAttention3D(nn.Module): |
||||
"""Self-attention layer for 3D parallel Vision Transformer |
||||
|
||||
:param hidden_size: hidden size |
||||
:type hidden_size: int |
||||
:param num_attention_heads: number of attention heads |
||||
:type num_attention_heads: int |
||||
:param attention_probs_dropout_prob: dropout probability for attention layers
:type attention_probs_dropout_prob: float
:param hidden_dropout_prob: dropout probability for hidden layers
:type hidden_dropout_prob: float
:param dtype: dtype of parameters, defaults to None
:type dtype: dtype, optional
:param bias: whether to add bias, defaults to True
:type bias: bool, optional
:param checkpoint: if set to ``True``, activation checkpointing is used, defaults to ``False``
:type checkpoint: bool, optional
""" |
||||
def __init__(self, |
||||
hidden_size: int, |
||||
num_attention_heads: int, |
||||
attention_probs_dropout_prob: float, |
||||
hidden_dropout_prob: float, |
||||
dtype: dtype = None, |
||||
bias: bool = True, |
||||
checkpoint: bool = False): |
||||
super().__init__() |
||||
self.depth = get_depth_from_env() |
||||
self.input_parallel_mode = ParallelMode.PARALLEL_3D_INPUT |
||||
self.weight_parallel_mode = ParallelMode.PARALLEL_3D_WEIGHT |
||||
self.output_parallel_mode = ParallelMode.PARALLEL_3D_OUTPUT |
||||
self.hidden_size = hidden_size |
||||
self.num_attention_heads = divide(num_attention_heads, self.depth) |
||||
self.attention_head_size = divide(hidden_size, num_attention_heads) |
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size |
||||
self.checkpoint = checkpoint |
||||
|
||||
self.query_key_value = Linear3D(self.hidden_size, |
||||
3 * self.hidden_size, |
||||
self.input_parallel_mode, |
||||
self.weight_parallel_mode, |
||||
dtype=dtype, |
||||
bias=bias) |
||||
self.attention_dropout = nn.Dropout(attention_probs_dropout_prob) |
||||
self.dense = Linear3D(self.hidden_size, |
||||
self.hidden_size, |
||||
self.output_parallel_mode, |
||||
self.weight_parallel_mode, |
||||
dtype=dtype, |
||||
bias=bias) |
||||
self.dropout = nn.Dropout(hidden_dropout_prob) |
||||
self.softmax = nn.Softmax(dim=-1) |
||||
|
||||
def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]: |
||||
return self.input_parallel_mode, self.weight_parallel_mode |
||||
|
||||
def _forward(self, hidden_states: Tensor) -> Tensor: |
||||
query_key_value = self.query_key_value(hidden_states) |
||||
new_qkv_shape = query_key_value.shape[:-1] + \ |
||||
(self.num_attention_heads, 3 * self.attention_head_size) |
||||
query_key_value = query_key_value.view(new_qkv_shape) |
||||
query_key_value = query_key_value.permute((0, 2, 1, 3)) |
||||
query_layer, key_layer, value_layer = torch.chunk(query_key_value, |
||||
3, |
||||
dim=-1) |
||||
|
||||
attention_scores = torch.matmul(query_layer, |
||||
key_layer.transpose(-1, -2)) |
||||
attention_scores = attention_scores / math.sqrt( |
||||
self.attention_head_size) |
||||
attention_probs = self.softmax(attention_scores) |
||||
with seed(ParallelMode.TENSOR): |
||||
attention_probs = self.attention_dropout(attention_probs) |
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer) |
||||
context_layer = context_layer.transpose(1, 2) |
||||
new_context_layer_shape = context_layer.size()[:-2] + ( |
||||
self.all_head_size, ) |
||||
context_layer = context_layer.reshape(new_context_layer_shape) |
||||
|
||||
output = self.dense(context_layer) |
||||
with seed(ParallelMode.TENSOR): |
||||
output = self.dropout(output) |
||||
|
||||
return output |
||||
|
||||
def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor: |
||||
return checkpoint(self._forward, hidden_states) |
||||
|
||||
def forward(self, hidden_states: Tensor) -> Tensor: |
||||
if self.checkpoint: |
||||
return self._checkpoint_forward(hidden_states) |
||||
else: |
||||
return self._forward(hidden_states) |
||||
|
||||
|
||||
@LAYERS.register_module |
||||
class ViTMLP3D(nn.Module): |
||||
"""[summary] |
||||
|
||||
:param hidden_size: hidden size |
||||
:type hidden_size: int |
||||
:param mlp_ratio: hidden size of MLP divided by embedding dim |
||||
:type mlp_ratio: int |
||||
:param hidden_dropout_prob: dropout probability for hidden layers |
||||
:type hidden_dropout_prob: float |
||||
:param hidden_act: activation function for hidden layers |
||||
:type hidden_act: str |
||||
:param dtype: dtype of parameters, defaults to None
:type dtype: dtype, optional
:param bias: whether to add bias, defaults to True
:type bias: bool, optional
:param checkpoint: if set to ``True``, activation checkpointing is used, defaults to ``False``
:type checkpoint: bool, optional
""" |
||||
def __init__(self, |
||||
hidden_size: int, |
||||
mlp_ratio: int, |
||||
hidden_dropout_prob: float, |
||||
hidden_act: str = 'gelu', |
||||
dtype: dtype = None, |
||||
bias: bool = True, |
||||
checkpoint: bool = False): |
||||
super().__init__() |
||||
self.depth = get_depth_from_env() |
||||
self.input_parallel_mode = ParallelMode.PARALLEL_3D_INPUT |
||||
self.weight_parallel_mode = ParallelMode.PARALLEL_3D_WEIGHT |
||||
self.output_parallel_mode = ParallelMode.PARALLEL_3D_OUTPUT |
||||
self.hidden_size = hidden_size |
||||
self.mlp_ratio = mlp_ratio |
||||
self.checkpoint = checkpoint |
||||
|
||||
self.dense_1 = Linear3D(self.hidden_size, |
||||
self.mlp_ratio * self.hidden_size, |
||||
self.input_parallel_mode, |
||||
self.weight_parallel_mode, |
||||
dtype=dtype, |
||||
bias=bias) |
||||
self.activation_func = ACT2FN[hidden_act] |
||||
self.dense_2 = Linear3D(self.mlp_ratio * self.hidden_size, |
||||
self.hidden_size, |
||||
self.output_parallel_mode, |
||||
self.weight_parallel_mode, |
||||
dtype=dtype, |
||||
bias=bias) |
||||
self.dropout = nn.Dropout(hidden_dropout_prob) |
||||
|
||||
def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]: |
||||
return self.input_parallel_mode, self.weight_parallel_mode |
||||
|
||||
def _forward(self, hidden_states: Tensor) -> Tensor: |
||||
intermediate_output = self.dense_1(hidden_states) |
||||
intermediate_output = self.activation_func(intermediate_output) |
||||
output = self.dense_2(intermediate_output) |
||||
with seed(ParallelMode.TENSOR): |
||||
output = self.dropout(output) |
||||
return output |
||||
|
||||
def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor: |
||||
return checkpoint(self._forward, hidden_states) |
||||
|
||||
def forward(self, hidden_states: Tensor) -> Tensor: |
||||
if self.checkpoint: |
||||
return self._checkpoint_forward(hidden_states) |
||||
else: |
||||
return self._forward(hidden_states) |
||||
|
||||
|
||||
@LAYERS.register_module |
||||
class ViTHead3D(nn.Module): |
||||
"""Output layer for 3D parallel Vision Transformer |
||||
|
||||
:param in_features: size of input tensor |
||||
:type in_features: int |
||||
:param num_classes: number of classes |
||||
:type num_classes: int |
||||
:param dtype: dtype of parameters, defaults to None |
||||
:type dtype: dtype, optional |
||||
:param bias: whether to add bias, defaults to True |
||||
:type bias: bool, optional |
||||
""" |
||||
def __init__(self, |
||||
in_features: int, |
||||
num_classes: int, |
||||
dtype: dtype = None, |
||||
bias: bool = True): |
||||
super().__init__() |
||||
self.depth = get_depth_from_env() |
||||
self.input_parallel_mode = ParallelMode.PARALLEL_3D_INPUT |
||||
self.weight_parallel_mode = ParallelMode.PARALLEL_3D_WEIGHT |
||||
self.output_parallel_mode = ParallelMode.PARALLEL_3D_OUTPUT |
||||
self.in_features = in_features |
||||
self.num_classes = num_classes |
||||
out_features = math.ceil(self.num_classes / |
||||
(self.depth**2)) * (self.depth**2) |
||||
self.num_classes_per_partition = divide(self.num_classes, self.depth) |
||||
self.linear = Linear3D(self.in_features, |
||||
out_features, |
||||
self.input_parallel_mode, |
||||
self.weight_parallel_mode, |
||||
dtype=dtype, |
||||
bias=bias) |
||||
|
||||
def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]: |
||||
return self.linear.groups_for_next_layer() |
||||
|
||||
def forward(self, x: Tensor) -> Tensor: |
||||
# [b/q^2, s, h/q] --> [b/q^2, h/q] |
||||
x = x[:, 0] |
||||
# [b/q^2, h/q] --> [b/q^2, c/q] |
||||
x = self.linear(x) |
||||
return x[:, :self.num_classes_per_partition] |
||||
|
||||
def extra_repr(self): |
||||
return 'in_features={}, num_classes={}'.format(self.in_features, |
||||
self.num_classes) |
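# Worked example of the class padding above (hypothetical sizes, depth q=4):
# num_classes=1000 is rounded up to out_features = ceil(1000/16) * 16 = 1008
# so that Linear3D can split it evenly across q^2 = 16 partitions, while
# num_classes_per_partition = 1000/4 = 250; the forward pass therefore keeps
# only the first 250 columns of each rank's [b/q^2, out_features/q] output.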
@ -0,0 +1,172 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
import math |
||||
from typing import Tuple |
||||
|
||||
import torch |
||||
import torch.nn as nn |
||||
from colossalai.context import ParallelMode, seed |
||||
from colossalai.registry import LAYERS |
||||
from colossalai.utils import get_current_device |
||||
from torch import Tensor, dtype |
||||
from torch.nn import Parameter |
||||
|
||||
from .._common_utils import divide, set_tensor_parallel_attribute |
||||
from ._operation import Add_3D, Matmul_AB_3D, Mul_3D, Sum_3D |
||||
from ._utils import get_depth_from_env, get_last_group |
||||
|
||||
|
||||
@LAYERS.register_module |
||||
class LayerNorm3D(nn.Module): |
||||
def __init__( |
||||
self, |
||||
normalized_shape: int, |
||||
input_parallel_mode: ParallelMode, |
||||
weight_parallel_mode: ParallelMode, |
||||
eps: float = 1e-12, |
||||
dtype: dtype = None, |
||||
): |
||||
super().__init__() |
||||
self.input_parallel_mode = input_parallel_mode |
||||
self.weight_parallel_mode = weight_parallel_mode |
||||
self.output_parallel_mode = get_last_group(self.input_parallel_mode, |
||||
self.weight_parallel_mode) |
||||
self.depth = get_depth_from_env() |
||||
self.normalized_shape = normalized_shape |
||||
self.normalized_shape_per_partition = divide(normalized_shape, |
||||
self.depth**2) |
||||
|
||||
self.weight = Parameter( |
||||
torch.ones(self.normalized_shape_per_partition, |
||||
device=get_current_device(), |
||||
dtype=dtype)) |
||||
self.bias = Parameter( |
||||
torch.zeros(self.normalized_shape_per_partition, |
||||
device=get_current_device(), |
||||
dtype=dtype)) |
||||
self.variance_epsilon = eps |
||||
self._set_tensor_parallel_attributes() |
||||
|
||||
def _set_tensor_parallel_attributes(self): |
||||
set_tensor_parallel_attribute(self.weight) |
||||
set_tensor_parallel_attribute(self.bias) |
||||
|
||||
def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]: |
||||
return self.input_parallel_mode, self.weight_parallel_mode |
||||
|
||||
def reset_parameters(self): |
||||
nn.init.zeros_(self.bias) |
||||
nn.init.ones_(self.weight) |
||||
|
||||
def forward(self, input_: Tensor) -> Tensor: |
||||
'''x = weight * (x - mean) / sqrt(var + eps) + bias''' |
||||
# input: [m/q^2, n, h/q] |
||||
# [m/q^2, n, 1] |
||||
mean = Sum_3D.apply(input_, -1, self.depth, self.output_parallel_mode, |
||||
True) / self.normalized_shape |
||||
# [m/q^2, n, 1] |
||||
var = (input_ - mean).pow(2) |
||||
var = Sum_3D.apply(var, -1, self.depth, self.output_parallel_mode, |
||||
True) / self.normalized_shape |
||||
|
||||
output = (input_ - mean) / torch.sqrt(var + self.variance_epsilon) |
||||
output = Mul_3D.apply(output, self.weight, self.depth, |
||||
self.input_parallel_mode, |
||||
self.weight_parallel_mode, |
||||
self.output_parallel_mode) |
||||
output = Add_3D.apply(output, self.bias, self.depth, |
||||
self.input_parallel_mode, |
||||
self.weight_parallel_mode, |
||||
self.output_parallel_mode) |
||||
return output |
||||
|
||||
def extra_repr(self): |
||||
return '{}, eps={}'.format(self.normalized_shape, |
||||
self.variance_epsilon) |
||||
|
||||
|
||||
@LAYERS.register_module |
||||
class Linear3D(nn.Module): |
||||
def __init__(self, |
||||
in_features: int, |
||||
out_features: int, |
||||
input_parallel_mode: ParallelMode, |
||||
weight_parallel_mode: ParallelMode, |
||||
bias: bool = True, |
||||
dtype: dtype = None): |
||||
super().__init__() |
||||
self.in_features = in_features |
||||
self.out_features = out_features |
||||
self.input_parallel_mode = input_parallel_mode |
||||
self.weight_parallel_mode = weight_parallel_mode |
||||
self.output_parallel_mode = get_last_group(self.input_parallel_mode, |
||||
self.weight_parallel_mode) |
||||
self.with_bias = bias |
||||
self.depth = get_depth_from_env() |
||||
self.in_features_per_partition = divide(in_features, self.depth) |
||||
self.out_features_per_partition = divide(out_features, self.depth**2) |
||||
|
||||
# [k/q, h/q^2] |
||||
self.weight = Parameter( |
||||
torch.empty(self.in_features_per_partition, |
||||
self.out_features_per_partition, |
||||
device=get_current_device(), |
||||
dtype=dtype)) |
||||
|
||||
# [h/q^2] |
||||
if bias: |
||||
self.bias = Parameter( |
||||
torch.zeros(self.out_features_per_partition, |
||||
device=get_current_device(), |
||||
dtype=dtype)) |
||||
else: |
||||
self.register_parameter('bias', None) |
||||
|
||||
self.reset_parameters() |
||||
self._set_tensor_parallel_attributes() |
||||
|
||||
def _set_tensor_parallel_attributes(self): |
||||
set_tensor_parallel_attribute(self.weight) |
||||
if self.bias is not None: |
||||
set_tensor_parallel_attribute(self.bias) |
||||
|
||||
def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]: |
||||
return self.output_parallel_mode, self.weight_parallel_mode |
||||
|
||||
def reset_parameters(self): |
||||
# setting |
||||
fan_in = self.in_features |
||||
a = math.sqrt(5) |
||||
nonlinearity = 'leaky_relu' |
||||
|
||||
# init weight |
||||
std = nn.init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in) |
||||
bound = math.sqrt(3.0) * std |
||||
with seed(ParallelMode.TENSOR): |
||||
nn.init.uniform_(self.weight, -bound, bound) |
||||
|
||||
# init bias |
||||
if self.with_bias: |
||||
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 |
||||
with seed(ParallelMode.TENSOR): |
||||
nn.init.uniform_(self.bias, -bound, bound) |
||||
|
||||
def forward(self, input_: Tensor) -> Tensor: |
||||
# input: [m/q^2, n, k/q] |
||||
# output: [m/q^2, n, h/q] |
||||
output = Matmul_AB_3D.apply(input_, self.weight, self.depth, |
||||
self.input_parallel_mode, |
||||
self.weight_parallel_mode, |
||||
self.output_parallel_mode) |
||||
|
||||
if self.with_bias: |
||||
output = Add_3D.apply(output, self.bias, self.depth, |
||||
self.output_parallel_mode, |
||||
self.weight_parallel_mode, |
||||
self.input_parallel_mode) |
||||
return output |
||||
|
||||
def extra_repr(self): |
||||
return 'in_features={}, out_features={}, bias={}'.format( |
||||
self.in_features, self.out_features, self.with_bias) |
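# Worked example of the partitioning above (hypothetical sizes, depth q=2):
# in_features=1024, out_features=4096 gives a local weight of shape
# [in_features/q, out_features/q^2] = [512, 1024] and a local bias of shape
# [out_features/q^2] = [1024]; Matmul_AB_3D then turns an input shard of
# shape [m/q^2, n, 512] into an output shard of shape [m/q^2, n, 4096/q].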
@ -0,0 +1,4 @@
|
||||
from ._operation import RingQK, RingAV |
||||
from .layers import TransformerSelfAttentionRing |
||||
|
||||
__all__ = ['TransformerSelfAttentionRing', 'RingAV', 'RingQK'] |
@ -0,0 +1,169 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
import torch |
||||
from torch import distributed as dist |
||||
|
||||
from colossalai.communication import ring_forward |
||||
from colossalai.context.parallel_mode import ParallelMode |
||||
from colossalai.core import global_context as gpc |
||||
from colossalai.nn.layer.parallel_sequence._utils import _calc_incoming_device_range, _calc_current_device_range |
||||
from colossalai.utils import get_current_device |
||||
|
||||
|
||||
class RingQK(torch.autograd.Function): |
||||
""" |
||||
Calculate QK in a ring-exchange style |
||||
""" |
||||
|
||||
@staticmethod |
||||
def forward(ctx, |
||||
sub_q, |
||||
sub_k, |
||||
batch_size, |
||||
num_attention_heads, |
||||
sub_seq_length): |
||||
# save tensor for backward |
||||
ctx.save_for_backward(sub_q, sub_k) |
||||
ctx.sub_seq_length = sub_seq_length |
||||
|
||||
# create local segment of attention score |
||||
attention_score = torch.empty( |
||||
batch_size * num_attention_heads, |
||||
sub_seq_length, |
||||
sub_seq_length * gpc.get_world_size(ParallelMode.SEQUENCE), |
||||
dtype=sub_q.dtype, |
||||
device=get_current_device() |
||||
) |
||||
|
||||
# compute local QK^T |
||||
part_a = torch.matmul(sub_q, sub_k.transpose(2, 1)) |
||||
local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE) |
||||
local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE) |
||||
start_idx = local_rank * sub_seq_length |
||||
end_idx = (local_rank + 1) * sub_seq_length |
||||
attention_score[:, :, start_idx: end_idx] = part_a |
||||
|
||||
# compute QK^T in ring-all-reduce style |
||||
for i in range(local_world_size - 1): |
||||
sub_k = ring_forward(sub_k, ParallelMode.SEQUENCE) |
||||
start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, sub_seq_length) |
||||
part_a = torch.matmul(sub_q, sub_k.transpose(2, 1)) |
||||
attention_score[:, :, start_idx:end_idx] = part_a |
||||
|
||||
return attention_score |
||||
|
||||
@staticmethod |
||||
def backward(ctx, grad_output): |
||||
sub_q, sub_k, = ctx.saved_tensors |
||||
local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE) |
||||
local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE) |
||||
|
||||
# calculate gradient of sub_k |
||||
grad_k = torch.matmul( |
||||
grad_output.transpose(2, 1), |
||||
sub_q |
||||
) |
||||
dist.all_reduce(grad_k, group=gpc.get_group(ParallelMode.SEQUENCE)) |
||||
grad_k = grad_k[:, local_rank * ctx.sub_seq_length: (local_rank + 1) * ctx.sub_seq_length] |
||||
grad_k /= local_world_size |
||||
|
||||
# calculate gradient for sub_q |
||||
grad_q = torch.zeros_like(sub_q, |
||||
dtype=sub_q.dtype, |
||||
device=get_current_device(), ) |
||||
|
||||
# compute with local sub_k |
||||
start_idx, end_idx = _calc_current_device_range(local_rank, ctx.sub_seq_length) |
||||
grad_q += torch.matmul(grad_output[:, :, start_idx:end_idx], sub_k) |
||||
|
||||
# accumulate grad_q with the incoming sub_k in ring-all-reduce style
||||
for i in range(local_world_size - 1): |
||||
sub_k = ring_forward(sub_k, ParallelMode.SEQUENCE) |
||||
start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, ctx.sub_seq_length) |
||||
grad_q += torch.matmul(grad_output[:, :, start_idx: end_idx], sub_k) |
||||
|
||||
grad_q /= local_world_size |
||||
|
||||
return grad_q, grad_k, None, None, None |
||||
|
||||
|
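# Sketch of the ring exchange above (hypothetical sizes): with a sequence
# group of world_size=4 and sub_seq_length=128, every rank owns sub_q and
# sub_k of shape [b*heads, 128, head_size]. The local matmul fills columns
# [rank*128, (rank+1)*128) of the [b*heads, 128, 512] score tensor, and each
# of the 3 ring_forward steps brings in the next rank's sub_k so the remaining
# 128-column blocks are filled, giving every rank the scores of its own
# queries against the full 512-token key sequence.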
||||
class RingAV(torch.autograd.Function): |
||||
""" |
||||
Calculate AV in a ring-exchange style |
||||
""" |
||||
|
||||
@staticmethod |
||||
def forward(ctx, |
||||
attention_score, |
||||
sub_v, |
||||
batch_size, |
||||
num_attention_heads, |
||||
attention_head_size, |
||||
sub_seq_length): |
||||
local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE) |
||||
local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE) |
||||
local_start_idx, local_end_idx = _calc_current_device_range(local_rank, sub_seq_length) |
||||
|
||||
sub_attention_result = torch.zeros( |
||||
batch_size * num_attention_heads, |
||||
sub_seq_length, |
||||
attention_head_size, |
||||
device=get_current_device(), |
||||
dtype=attention_score.dtype) |
||||
|
||||
# save tensors for backward |
||||
ctx.save_for_backward(attention_score, sub_v) |
||||
ctx.sub_seq_length = sub_seq_length |
||||
|
||||
# compute local AV |
||||
part_av = torch.matmul(attention_score[:, :, local_start_idx:local_end_idx], sub_v) |
||||
sub_attention_result += part_av |
||||
|
||||
# compute AV in ring-all-reduce style
||||
for i in range(local_world_size - 1): |
||||
sub_v = ring_forward(sub_v, ParallelMode.SEQUENCE) |
||||
start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, sub_seq_length) |
||||
|
||||
# compute AV with the incoming sub_v chunk
||||
part_av = torch.matmul(attention_score[:, :, start_idx:end_idx], sub_v) |
||||
sub_attention_result += part_av |
||||
return sub_attention_result |
||||
|
||||
@staticmethod |
||||
def backward(ctx, grad_output): |
||||
local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE) |
||||
local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE) |
||||
local_start_idx, local_end_idx = _calc_current_device_range(local_rank, ctx.sub_seq_length) |
||||
attention_scores, sub_v = ctx.saved_tensors |
||||
|
||||
# calculate gradient of v |
||||
grad_v = torch.matmul( |
||||
attention_scores.transpose(2, 1), |
||||
grad_output |
||||
) |
||||
dist.all_reduce(grad_v, group=gpc.get_group(ParallelMode.SEQUENCE)) |
||||
grad_v = grad_v[:, local_start_idx:local_end_idx] |
||||
grad_v /= local_world_size |
||||
|
||||
# calculate gradient for attention score |
||||
grad_attention_score = torch.zeros_like(attention_scores, |
||||
dtype=grad_output.dtype, |
||||
device=get_current_device()) |
||||
|
||||
# compute with local sub_v
||||
grad_attention_score[:, :, local_start_idx:local_end_idx] += torch.matmul( |
||||
grad_output, |
||||
sub_v.transpose(2, 1)) |
||||
|
||||
# accumulate the attention-score gradient with the incoming sub_v chunks
||||
for i in range(local_world_size - 1): |
||||
sub_v = ring_forward(sub_v, ParallelMode.SEQUENCE) |
||||
start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, ctx.sub_seq_length) |
||||
|
||||
# gradient of the attention score w.r.t. the incoming sub_v chunk
||||
grad_attention_score[:, :, start_idx:end_idx] += torch.matmul( |
||||
grad_output, |
||||
sub_v.transpose(2, 1)) |
||||
|
||||
return grad_attention_score, grad_v, None, None, None, None |
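The ring kernels above only make sense with an initialized sequence-parallel group, but the invariant they rely on can be checked on a single process. The sketch below (toy shapes, plain `torch`, no `torch.distributed`, all names hypothetical) shows that filling the score matrix one column block per K chunk reproduces the dense Q K^T that `RingQK` computes collectively.

```python
import torch

# Single-process illustration of the ring-exchange idea behind RingQK: splitting K along
# the sequence dimension and writing each partial Q K^T into its column slice gives the
# same result as one dense matmul.
batch_heads, sub_seq, dim, world_size = 4, 8, 16, 4
seq = sub_seq * world_size

q_full = torch.randn(batch_heads, seq, dim)
k_full = torch.randn(batch_heads, seq, dim)

sub_q = q_full[:, :sub_seq]                       # the chunk a single rank would own
scores = torch.empty(batch_heads, sub_seq, seq)
for rank in range(world_size):                    # one iteration per "incoming" K chunk
    sub_k = k_full[:, rank * sub_seq:(rank + 1) * sub_seq]
    scores[:, :, rank * sub_seq:(rank + 1) * sub_seq] = torch.matmul(sub_q, sub_k.transpose(2, 1))

assert torch.allclose(scores, torch.matmul(sub_q, k_full.transpose(2, 1)))
```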
@ -0,0 +1,15 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
|
||||
def _calc_incoming_device_range(i, rank, world_size, sub_seq_length): |
||||
device_of_incoming_k = (rank - i - 1) % world_size |
||||
start_idx = sub_seq_length * device_of_incoming_k |
||||
end_idx = sub_seq_length * (device_of_incoming_k + 1) |
||||
return start_idx, end_idx |
||||
|
||||
|
||||
def _calc_current_device_range(rank, sub_seq_length): |
||||
start_idx = sub_seq_length * rank |
||||
end_idx = sub_seq_length * (rank + 1) |
||||
return start_idx, end_idx |
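A quick worked example of the two index helpers (restated here so the snippet runs on its own): with 4 ranks and a sub-sequence length of 8, rank 1 owns positions 8–16, receives rank 0's chunk on the first ring step, and rank 3's chunk on the second.

```python
# Restated copies of the helpers above so the example is self-contained.
def _calc_incoming_device_range(i, rank, world_size, sub_seq_length):
    device_of_incoming_k = (rank - i - 1) % world_size
    return sub_seq_length * device_of_incoming_k, sub_seq_length * (device_of_incoming_k + 1)

def _calc_current_device_range(rank, sub_seq_length):
    return sub_seq_length * rank, sub_seq_length * (rank + 1)

rank, world_size, sub_seq_length = 1, 4, 8
print(_calc_current_device_range(rank, sub_seq_length))                   # (8, 16): rank 1's own slice
print(_calc_incoming_device_range(0, rank, world_size, sub_seq_length))   # (0, 8):  chunk from rank 0
print(_calc_incoming_device_range(1, rank, world_size, sub_seq_length))   # (24, 32): chunk from rank 3
```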
@ -0,0 +1,188 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
import math |
||||
|
||||
import torch |
||||
import torch.nn as nn |
||||
import torch.nn.functional as F |
||||
|
||||
from colossalai.context.parallel_mode import ParallelMode |
||||
from colossalai.core import global_context as gpc |
||||
from colossalai.nn.layer.parallel_sequence._operation import RingQK, RingAV |
||||
from colossalai.registry import LAYERS |
||||
|
||||
|
||||
@LAYERS.register_module |
||||
class TransformerSelfAttentionRing(nn.Module): |
||||
"""Parallel self-attention layer abstract class. |
||||
Self-attention layer takes input with size [b, s, h] |
||||
and returns output of the same size. |
||||
|
||||
:param hidden_size: hidden size |
||||
:type hidden_size: int |
||||
:param kv_channels: channels of key/value tensor |
||||
:type kv_channels: int |
||||
:param num_attention_heads: number of attention heads |
||||
:type num_attention_heads: int |
||||
:param attention_dropout: dropout probability for attention layer |
||||
:type attention_dropout: float |
||||
""" |
||||
|
||||
def __init__(self, |
||||
hidden_size, |
||||
kv_channels, |
||||
num_attention_heads, |
||||
attention_dropout, |
||||
): |
||||
super().__init__() |
||||
|
||||
self.hidden_size = hidden_size |
||||
self.num_attention_heads = num_attention_heads |
||||
|
||||
projection_size = kv_channels * num_attention_heads |
||||
self.hidden_size_per_attention_head = projection_size // num_attention_heads |
||||
|
||||
self.world_size = gpc.get_world_size(ParallelMode.SEQUENCE) |
||||
|
||||
# Strided linear layer. |
||||
self.query_key_value = nn.Linear( |
||||
hidden_size, |
||||
3 * projection_size, |
||||
) |
||||
|
||||
# coeff = None |
||||
self.norm_factor = math.sqrt(self.hidden_size) |
||||
|
||||
# TODO: add apply_query_key_layer_scaling when we have the kernel module |
||||
# if self.apply_query_key_layer_scaling: |
||||
# coeff = self.layer_number |
||||
# self.norm_factor *= coeff |
||||
|
||||
# TODO: add fused scale mask softmax kernel when we have the kernel module |
||||
# self.scale_mask_softmax = FusedScaleMaskSoftmax( |
||||
# self.fp16, self.bf16, |
||||
# self.attn_mask_type, |
||||
# masked_softmax_fusion, |
||||
# attention_mask_func, |
||||
# self.attention_softmax_in_fp32, |
||||
# coeff) |
||||
|
||||
self.attention_dropout = nn.Dropout(attention_dropout) |
||||
|
||||
# Output. |
||||
self.dense = nn.Linear( |
||||
projection_size, |
||||
hidden_size, |
||||
bias=True) |
||||
|
||||
def forward(self, hidden_states, attention_mask): |
||||
# hidden_states: [sq, b, h] |
||||
|
||||
sub_seq_length, batch_size, hidden_size = hidden_states.size() |
||||
|
||||
# ===================== |
||||
# Query, Key, and Value |
||||
# ===================== |
||||
|
||||
# Attention heads [sq, b, h] --> [sq, b, (3 * hn * num_heads)] |
||||
mixed_x_layer = self.query_key_value(hidden_states) |
||||
|
||||
# [sq, b, num_heads, 3 * hn] --> 3 [sq, b, num_heads, hn] |
||||
new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_attention_heads, |
||||
3 * self.hidden_size_per_attention_head) |
||||
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) |
||||
|
||||
# split into query, key and value |
||||
last_dim = mixed_x_layer.dim() - 1 |
||||
last_dim_value = mixed_x_layer.size()[-1] |
||||
assert last_dim_value % 3 == 0, 'the last dimension is not a multiple of 3, ' \ |
||||
'cannot be divided into query, key and value' |
||||
partition_size = last_dim_value // 3 |
||||
(query_layer, key_layer, value_layer) = torch.split( |
||||
mixed_x_layer, partition_size, dim=last_dim) |
||||
|
||||
# =================================== |
||||
# Raw attention scores. [b, num_heads, s, s] |
||||
# =================================== |
||||
|
||||
# [b, num_heads, sq, sk] |
||||
output_size = (query_layer.size(1), |
||||
query_layer.size(2), |
||||
query_layer.size(0), |
||||
key_layer.size(0) * self.world_size) |
||||
|
||||
# [sq, b, num_heads, hn] -> [sq, b * num_heads, hn] |
||||
query_layer = query_layer.view(output_size[2], |
||||
output_size[0] * output_size[1], -1) |
||||
# [sk, b, num_heads, hn] -> [sk, b * num_heads, hn] |
||||
key_layer = key_layer.view(key_layer.size(0), |
||||
output_size[0] * output_size[1], -1) |
||||
|
||||
# [b, sq, sk] |
||||
attention_scores = RingQK.apply( |
||||
# [b * num_heads, sq, hn] |
||||
query_layer.transpose(0, 1).contiguous(), |
||||
key_layer.transpose(0, 1).contiguous(), # [b * num_heads, sk, hn], |
||||
batch_size, |
||||
self.num_attention_heads, |
||||
sub_seq_length |
||||
) |
||||
attention_scores /= self.norm_factor |
||||
|
||||
# change view to [b, num_heads, sq, sk] |
||||
attention_scores = attention_scores.view(*output_size) |
||||
attention_scores = attention_scores.unsqueeze(1) |
||||
|
||||
attention_scores = attention_scores + attention_mask |
||||
attention_probs = F.softmax(attention_scores, dim=-1) |
||||
attention_probs = attention_probs.squeeze(1) |
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might |
||||
# seem a bit unusual, but is taken from the original Transformer paper. |
||||
# with mpu.get_cuda_rng_tracker().fork(): |
||||
# TODO: check if a rng tracker is needed |
||||
attention_probs = self.attention_dropout(attention_probs) |
||||
|
||||
# context layer shape: [b, num_heads, sq, hn] |
||||
output_size = (value_layer.size(1), |
||||
value_layer.size(2), |
||||
query_layer.size(0), |
||||
value_layer.size(3)) |
||||
# |
||||
# # change view [sk, b * num_heads, hn] |
||||
value_layer = value_layer.contiguous().view(value_layer.size(0), |
||||
output_size[0] * output_size[1], -1) |
||||
|
||||
# # change view [b * num_heads, sq, sk] |
||||
attention_probs = attention_probs.view(attention_probs.size(0) * attention_probs.size(1), |
||||
attention_probs.size(2), |
||||
attention_probs.size(3)) |
||||
|
||||
# matmul: [b*num_heads, sq, hn] |
||||
# context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) |
||||
context_layer = RingAV.apply( |
||||
attention_probs, |
||||
value_layer.transpose(0, 1).contiguous(), |
||||
batch_size, |
||||
self.num_attention_heads, |
||||
self.hidden_size_per_attention_head, |
||||
sub_seq_length |
||||
) |
||||
|
||||
# # change view [b, num_heads, sq, hn] |
||||
context_layer = context_layer.view(*output_size) |
||||
|
||||
# # [b, np, sq, hn] --> [sq, b, np, hn] |
||||
context_layer = context_layer.permute(2, 0, 1, 3).contiguous() |
||||
|
||||
# # [sq, b, np, hn] --> [sq, b, hp] |
||||
new_context_layer_shape = context_layer.size()[:-2] + ( |
||||
self.hidden_size_per_attention_head * self.num_attention_heads,) |
||||
context_layer = context_layer.view(*new_context_layer_shape) |
||||
|
||||
# context_layer = context_layer.transpose(1, 0).contiguous() |
||||
output = self.dense(context_layer) |
||||
bias = self.dense.bias |
||||
|
||||
return output, bias |
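The view/split bookkeeping in `forward` is easy to lose track of. The hedged sketch below reproduces just the QKV reshaping on a toy tensor (hypothetical sizes, no parallel context, plain `torch.nn.Linear` standing in for the strided projection) so the intermediate shapes are visible.

```python
import torch
import torch.nn as nn

# Hypothetical toy sizes: sub-sequence 8, batch 2, 3 heads of size 16.
sub_seq, batch, heads, head_dim = 8, 2, 3, 16
hidden = heads * head_dim

qkv_proj = nn.Linear(hidden, 3 * hidden)
hidden_states = torch.randn(sub_seq, batch, hidden)         # [sq, b, h]

mixed = qkv_proj(hidden_states)                             # [sq, b, 3 * num_heads * hn]
mixed = mixed.view(sub_seq, batch, heads, 3 * head_dim)     # [sq, b, num_heads, 3 * hn]
q, k, v = torch.split(mixed, head_dim, dim=-1)              # three [sq, b, num_heads, hn]

# Collapse batch and heads the way RingQK/RingAV expect: [b * num_heads, sq, hn]
q = q.reshape(sub_seq, batch * heads, head_dim).transpose(0, 1).contiguous()
k = k.reshape(sub_seq, batch * heads, head_dim).transpose(0, 1).contiguous()
print(q.shape, k.shape)  # torch.Size([6, 8, 16]) torch.Size([6, 8, 16])
```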
@ -0,0 +1,3 @@
|
||||
from .layers import ViTBlock |
||||
|
||||
__all__ = ['ViTBlock'] |
@ -0,0 +1,59 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
from torch import nn as nn |
||||
|
||||
from colossalai.builder import build_layer |
||||
from colossalai.registry import LAYERS |
||||
|
||||
|
||||
@LAYERS.register_module |
||||
class ViTBlock(nn.Module): |
||||
"""Vision Transformer block |
||||
|
||||
:param attention_cfg: config of attention layer |
||||
:type attention_cfg: dict |
||||
:param droppath_cfg: config of drop path |
||||
:type droppath_cfg: dict |
||||
:param mlp_cfg: config of MLP layer |
||||
:type mlp_cfg: dict |
||||
:param norm_cfg: config of normalization layer
||||
:type norm_cfg: dict |
||||
""" |
||||
|
||||
def __init__(self, |
||||
attention_cfg: dict, |
||||
droppath_cfg: dict, |
||||
mlp_cfg: dict, |
||||
norm_cfg: dict, |
||||
): |
||||
super().__init__() |
||||
self.norm1 = build_layer(norm_cfg) |
||||
self.attn = build_layer(attention_cfg) |
||||
self.drop_path = build_layer( |
||||
droppath_cfg) if droppath_cfg['drop_path'] > 0. else nn.Identity() |
||||
self.norm2 = build_layer(norm_cfg) |
||||
self.mlp = build_layer(mlp_cfg) |
||||
|
||||
def forward(self, x): |
||||
x = x + self.drop_path(self.attn(self.norm1(x))) |
||||
x = x + self.drop_path(self.mlp(self.norm2(x))) |
||||
|
||||
# x_ = x |
||||
# x_ = self.norm1(x_) |
||||
# if self.checkpoint: |
||||
# x_ = checkpoint(self.attn, x_) |
||||
# else: |
||||
# x_ = self.attn(x_) |
||||
# x_ = self.drop_path(x_) |
||||
# x = x + x_ |
||||
# |
||||
# x_ = x |
||||
# x_ = self.norm2(x_) |
||||
# if self.checkpoint: |
||||
# x_ = checkpoint(self.mlp, x_) |
||||
# else: |
||||
# x_ = self.mlp(x_) |
||||
# x_ = self.drop_path(x_) |
||||
# x = x + x_ |
||||
return x |
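`ViTBlock` defers construction of its sub-layers to `build_layer` and the registry, so the exact config schema depends on the registered layer classes. As a rough sketch of the computation only (ordinary `torch.nn` modules standing in for the config-built ones, drop path omitted, all sizes hypothetical), the pre-norm residual pattern in `forward` looks like this:

```python
import torch
import torch.nn as nn

dim, heads = 192, 3  # hypothetical sizes

class PlainPreNormBlock(nn.Module):
    """Same residual pattern as ViTBlock.forward, with plain modules instead of build_layer."""
    def __init__(self, dim, heads):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x):
        y = self.norm1(x)
        x = x + self.attn(y, y, y, need_weights=False)[0]  # attention branch + residual
        x = x + self.mlp(self.norm2(x))                    # MLP branch + residual
        return x

x = torch.randn(2, 197, dim)                   # [batch, tokens, dim]
print(PlainPreNormBlock(dim, heads)(x).shape)  # torch.Size([2, 197, 192])
```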
@ -0,0 +1,5 @@
|
||||
from .basic_block import ResNetBasicBlock |
||||
from .bottleneck import ResNetBottleneck |
||||
from .reslayer import ResLayer |
||||
|
||||
__all__ = ['ResLayer', 'ResNetBottleneck', 'ResNetBasicBlock'] |
@ -0,0 +1,64 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
from typing import Optional, Callable |
||||
|
||||
import torch.nn as nn |
||||
from torch import Tensor |
||||
|
||||
from colossalai.registry import LAYERS |
||||
from .conv import conv3x3 |
||||
|
||||
|
||||
@LAYERS.register_module |
||||
class ResNetBasicBlock(nn.Module): |
||||
"""Basic ResNet block |
||||
""" |
||||
expansion: int = 1 |
||||
|
||||
def __init__( |
||||
self, |
||||
inplanes: int, |
||||
planes: int, |
||||
stride: int = 1, |
||||
downsample: Optional[nn.Module] = None, |
||||
groups: int = 1, |
||||
base_width: int = 64, |
||||
dilation: int = 1, |
||||
norm_layer: Optional[Callable[..., nn.Module]] = None |
||||
) -> None: |
||||
super().__init__() |
||||
if norm_layer is None: |
||||
norm_layer = nn.BatchNorm2d |
||||
if groups != 1 or base_width != 64: |
||||
raise ValueError( |
||||
'BasicBlock only supports groups=1 and base_width=64') |
||||
if dilation > 1: |
||||
raise NotImplementedError( |
||||
"Dilation > 1 not supported in BasicBlock") |
||||
# Both self.conv1 and self.downsample layers downsample the input when stride != 1 |
||||
self.conv1 = conv3x3(inplanes, planes, stride) |
||||
self.bn1 = norm_layer(planes) |
||||
self.relu = nn.ReLU(inplace=True) |
||||
self.conv2 = conv3x3(planes, planes) |
||||
self.bn2 = norm_layer(planes) |
||||
self.downsample = downsample |
||||
self.stride = stride |
||||
|
||||
def forward(self, x: Tensor) -> Tensor: |
||||
identity = x |
||||
|
||||
out = self.conv1(x) |
||||
out = self.bn1(out) |
||||
out = self.relu(out) |
||||
|
||||
out = self.conv2(out) |
||||
out = self.bn2(out) |
||||
|
||||
if self.downsample is not None: |
||||
identity = self.downsample(x) |
||||
|
||||
out += identity |
||||
out = self.relu(out) |
||||
|
||||
return out |
@ -0,0 +1,69 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
from typing import Optional, Callable |
||||
|
||||
import torch.nn as nn |
||||
from torch import Tensor |
||||
|
||||
from colossalai.registry import LAYERS |
||||
from .conv import conv3x3, conv1x1 |
||||
|
||||
|
||||
@LAYERS.register_module |
||||
class ResNetBottleneck(nn.Module): |
||||
# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) |
||||
# while original implementation places the stride at the first 1x1 convolution(self.conv1) |
||||
# according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. |
||||
# This variant is also known as ResNet V1.5 and improves accuracy according to |
||||
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. |
||||
|
||||
expansion: int = 4 |
||||
|
||||
def __init__( |
||||
self, |
||||
inplanes: int, |
||||
planes: int, |
||||
stride: int = 1, |
||||
downsample: Optional[nn.Module] = None, |
||||
groups: int = 1, |
||||
base_width: int = 64, |
||||
dilation: int = 1, |
||||
norm_layer: Optional[Callable[..., nn.Module]] = None |
||||
) -> None: |
||||
super().__init__() |
||||
if norm_layer is None: |
||||
norm_layer = nn.BatchNorm2d |
||||
width = int(planes * (base_width / 64.)) * groups |
||||
# Both self.conv2 and self.downsample layers downsample the input when stride != 1 |
||||
self.conv1 = conv1x1(inplanes, width) |
||||
self.bn1 = norm_layer(width) |
||||
self.conv2 = conv3x3(width, width, stride, groups, dilation) |
||||
self.bn2 = norm_layer(width) |
||||
self.conv3 = conv1x1(width, planes * self.expansion) |
||||
self.bn3 = norm_layer(planes * self.expansion) |
||||
self.relu = nn.ReLU(inplace=True) |
||||
self.downsample = downsample |
||||
self.stride = stride |
||||
|
||||
def forward(self, x: Tensor) -> Tensor: |
||||
identity = x |
||||
|
||||
out = self.conv1(x) |
||||
out = self.bn1(out) |
||||
out = self.relu(out) |
||||
|
||||
out = self.conv2(out) |
||||
out = self.bn2(out) |
||||
out = self.relu(out) |
||||
|
||||
out = self.conv3(out) |
||||
out = self.bn3(out) |
||||
|
||||
if self.downsample is not None: |
||||
identity = self.downsample(x) |
||||
|
||||
out += identity |
||||
out = self.relu(out) |
||||
|
||||
return out |
@ -0,0 +1,15 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
import torch.nn as nn |
||||
|
||||
|
||||
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: |
||||
"""3x3 convolution with padding""" |
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, |
||||
padding=dilation, groups=groups, bias=False, dilation=dilation) |
||||
|
||||
|
||||
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: |
||||
"""1x1 convolution""" |
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) |
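A small shape check for the two helpers (restated so the snippet runs standalone): a stride-2 `conv3x3` halves the spatial resolution, while `conv1x1` only changes the channel count, which is how the downsample shortcut in the blocks keeps the residual addition shape-compatible.

```python
import torch
import torch.nn as nn

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)

def conv1x1(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

x = torch.randn(1, 64, 56, 56)
print(conv3x3(64, 128, stride=2)(x).shape)  # torch.Size([1, 128, 28, 28])
print(conv1x1(64, 256)(x).shape)            # torch.Size([1, 256, 56, 56])
```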
@ -0,0 +1,63 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
import torch.nn as nn |
||||
|
||||
from colossalai.registry import LAYERS |
||||
from .conv import conv1x1 |
||||
|
||||
|
||||
@LAYERS.register_module |
||||
class ResLayer(nn.Module): |
||||
|
||||
def __init__(self, |
||||
block_type: str, |
||||
norm_layer_type: str, |
||||
inplanes: int, |
||||
planes: int, |
||||
blocks: int, |
||||
groups: int, |
||||
base_width: int, |
||||
stride: int = 1, |
||||
dilation: int = 1, |
||||
dilate: bool = False, |
||||
): |
||||
super().__init__() |
||||
self.block = LAYERS.get_module(block_type) |
||||
self.norm_layer = LAYERS.get_module(norm_layer_type) |
||||
self.inplanes = inplanes |
||||
self.planes = planes |
||||
self.blocks = blocks |
||||
self.groups = groups |
||||
self.dilation = dilation |
||||
self.base_width = base_width |
||||
self.dilate = dilate |
||||
self.stride = stride |
||||
self.layer = self._make_layer() |
||||
|
||||
def _make_layer(self): |
||||
norm_layer = self.norm_layer |
||||
downsample = None |
||||
previous_dilation = self.dilation |
||||
if self.dilate: |
||||
self.dilation *= self.stride |
||||
self.stride = 1 |
||||
if self.stride != 1 or self.inplanes != self.planes * self.block.expansion: |
||||
downsample = nn.Sequential( |
||||
conv1x1(self.inplanes, self.planes * self.block.expansion, self.stride), |
||||
norm_layer(self.planes * self.block.expansion), |
||||
) |
||||
|
||||
layers = [] |
||||
layers.append(self.block(self.inplanes, self.planes, self.stride, downsample, self.groups, |
||||
self.base_width, previous_dilation, norm_layer)) |
||||
self.inplanes = self.planes * self.block.expansion |
||||
for _ in range(1, self.blocks): |
||||
layers.append(self.block(self.inplanes, self.planes, groups=self.groups, |
||||
base_width=self.base_width, dilation=self.dilation, |
||||
norm_layer=norm_layer)) |
||||
|
||||
return nn.Sequential(*layers) |
||||
|
||||
def forward(self, x): |
||||
return self.layer(x) |
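The only branching logic in `_make_layer` is whether a downsample shortcut is needed. A quick numeric illustration with hypothetical stage sizes:

```python
# The first block of a stage typically changes resolution and/or channel count, so it
# gets a conv1x1 + norm shortcut; later blocks in the same stage reuse the identity.
expansion = 4                       # e.g. ResNetBottleneck.expansion
inplanes, planes, stride = 256, 128, 2
print(stride != 1 or inplanes != planes * expansion)    # True  -> downsample inserted
inplanes = planes * expansion                            # 512, as updated after the first block
print(1 != 1 or inplanes != planes * expansion)          # False -> remaining blocks keep identity
```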
@ -0,0 +1,7 @@
|
||||
from .layers import (VanillaViTBlock, VanillaViTMLP, VanillaViTPatchEmbedding, |
||||
VanillaViTAttention, VanillaViTDropPath, VanillaViTHead) |
||||
|
||||
__all__ = [ |
||||
'VanillaViTBlock', 'VanillaViTMLP', 'VanillaViTPatchEmbedding', |
||||
'VanillaViTAttention', 'VanillaViTDropPath', 'VanillaViTHead' |
||||
] |
@ -0,0 +1,244 @@
|
||||
import collections.abc |
||||
from itertools import repeat |
||||
|
||||
import torch |
||||
from torch import nn as nn |
||||
|
||||
from colossalai.registry import LAYERS |
||||
|
||||
|
||||
# From PyTorch internals |
||||
def _ntuple(n): |
||||
def parse(x): |
||||
if isinstance(x, collections.abc.Iterable): |
||||
return x |
||||
return tuple(repeat(x, n)) |
||||
|
||||
return parse |
||||
|
||||
|
||||
to_2tuple = _ntuple(2) |
||||
|
||||
|
||||
@LAYERS.register_module |
||||
class VanillaViTPatchEmbedding(nn.Module): |
||||
""" 2D Image to Patch Embedding |
||||
""" |
||||
|
||||
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, drop=0.): |
||||
super().__init__() |
||||
img_size = to_2tuple(img_size) |
||||
patch_size = to_2tuple(patch_size) |
||||
self.img_size = img_size |
||||
self.patch_size = patch_size |
||||
self.grid_size = (img_size[0] // patch_size[0], |
||||
img_size[1] // patch_size[1]) |
||||
self.num_patches = self.grid_size[0] * self.grid_size[1] |
||||
self.flatten = flatten |
||||
|
||||
self.proj = nn.Conv2d(in_chans, embed_dim, |
||||
kernel_size=patch_size, stride=patch_size) |
||||
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() |
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) |
||||
self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim)) |
||||
self.pos_drop = nn.Dropout(p=drop) |
||||
|
||||
def forward(self, x): |
||||
B, C, H, W = x.shape |
||||
assert H == self.img_size[0] and W == self.img_size[1], \ |
||||
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." |
||||
x = self.proj(x) |
||||
if self.flatten: |
||||
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC |
||||
x = self.norm(x) |
||||
cls_token = self.cls_token.expand(x.shape[0], -1, -1) |
||||
x = torch.cat((cls_token, x), dim=1) |
||||
x = self.pos_drop(x + self.pos_embed) |
||||
return x |
||||
|
||||
|
||||
@LAYERS.register_module |
||||
class VanillaViTMLP(nn.Module): |
||||
""" MLP as used in Vision Transformer, MLP-Mixer and related networks |
||||
""" |
||||
|
||||
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): |
||||
super().__init__() |
||||
out_features = out_features or in_features |
||||
hidden_features = hidden_features or in_features |
||||
self.fc1 = nn.Linear(in_features, hidden_features) |
||||
self.act = act_layer() |
||||
self.fc2 = nn.Linear(hidden_features, out_features) |
||||
self.drop = nn.Dropout(drop) |
||||
|
||||
def forward(self, x): |
||||
x = self.fc1(x) |
||||
x = self.act(x) |
||||
x = self.drop(x) |
||||
x = self.fc2(x) |
||||
x = self.drop(x) |
||||
return x |
||||
|
||||
|
||||
def drop_path(x, drop_prob: float = 0., training: bool = False): |
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). |
||||
|
||||
This is the same as the DropConnect implementation used for EfficientNet-style networks; however,
the original name is misleading, as 'Drop Connect' refers to a different form of dropout from a
separate paper (see the discussion at https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956).
The layer and argument are therefore named 'drop path' rather than mixing 'DropConnect' as a layer name
with 'survival rate' as the argument.
||||
|
||||
""" |
||||
if drop_prob == 0. or not training: |
||||
return x |
||||
keep_prob = 1 - drop_prob |
||||
# work with diff dim tensors, not just 2D ConvNets |
||||
shape = (x.shape[0],) + (1,) * (x.ndim - 1) |
||||
random_tensor = keep_prob + \ |
||||
torch.rand(shape, dtype=x.dtype, device=x.device) |
||||
random_tensor.floor_() # binarize |
||||
output = x.div(keep_prob) * random_tensor |
||||
return output |
||||
|
||||
|
||||
@LAYERS.register_module |
||||
class VanillaViTDropPath(nn.Module): |
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). |
||||
""" |
||||
|
||||
def __init__(self, drop_prob=0.): |
||||
super().__init__() |
||||
self.drop_prob = drop_prob |
||||
|
||||
def forward(self, x): |
||||
return drop_path(x, self.drop_prob, self.training) |
||||
|
||||
|
||||
@LAYERS.register_module |
||||
class VanillaViTAttention(nn.Module): |
||||
"""Vanilla attention layer of Vision Transformer |
||||
|
||||
:param dim: dimension of input tensor |
||||
:type dim: int |
||||
:param num_heads: number of attention heads, defaults to 8 |
||||
:type num_heads: int, optional |
||||
:param qkv_bias: enable bias for qkv if True, defaults to False |
||||
:type qkv_bias: bool, optional |
||||
:param attn_drop: dropout probability for attention layer, defaults to 0. |
||||
:type attn_drop: float, optional |
||||
:param proj_drop: dropout probability for linear layer, defaults to 0. |
||||
:type proj_drop: float, optional |
||||
""" |
||||
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): |
||||
super().__init__() |
||||
self.num_heads = num_heads |
||||
head_dim = dim // num_heads |
||||
self.scale = head_dim ** -0.5 |
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) |
||||
self.attn_drop = nn.Dropout(attn_drop) |
||||
self.proj = nn.Linear(dim, dim) |
||||
self.proj_drop = nn.Dropout(proj_drop) |
||||
|
||||
def forward(self, x): |
||||
B, N, C = x.shape |
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // |
||||
self.num_heads).permute(2, 0, 3, 1, 4) |
||||
# make torchscript happy (cannot use tensor as tuple) |
||||
q, k, v = qkv[0], qkv[1], qkv[2] |
||||
|
||||
attn = (q @ k.transpose(-2, -1)) * self.scale |
||||
attn = attn.softmax(dim=-1) |
||||
attn = self.attn_drop(attn) |
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N, C) |
||||
x = self.proj(x) |
||||
x = self.proj_drop(x) |
||||
return x |
||||
|
||||
|
||||
@LAYERS.register_module |
||||
class VanillaViTBlock(nn.Module): |
||||
|
||||
"""Vanilla Vision Transformer block |
||||
|
||||
:param dim: dimension of input tensor |
||||
:type dim: int |
||||
:param num_heads: number of attention heads |
||||
:type num_heads: int |
||||
:param mlp_ratio: hidden size of MLP divided by embedding dim, defaults to 4. |
||||
:type mlp_ratio: float, optional |
||||
:param qkv_bias: enable bias for qkv if True, defaults to False |
||||
:type qkv_bias: bool, optional |
||||
:param drop: dropout probability, defaults to 0. |
||||
:type drop: float, optional |
||||
:param attn_drop: dropout probability for attention layer, defaults to 0. |
||||
:type attn_drop: float, optional |
||||
:param drop_path: drop path probability, defaults to 0. |
||||
:type drop_path: float, optional |
||||
:param act_layer: activation function, defaults to nn.GELU |
||||
:type act_layer: torch.nn.Module, optional |
||||
:param norm_layer: normalization layer, defaults to nn.LayerNorm |
||||
:type norm_layer: torch.nn.Module, optional |
||||
""" |
||||
|
||||
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., |
||||
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): |
||||
super().__init__() |
||||
self.norm1 = norm_layer(dim) |
||||
self.attn = LAYERS.get_module('VanillaViTAttention')(dim, |
||||
num_heads=num_heads, |
||||
qkv_bias=qkv_bias, |
||||
attn_drop=attn_drop, |
||||
proj_drop=drop) |
||||
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here |
||||
self.drop_path = LAYERS.get_module('VanillaViTDropPath')( |
||||
drop_path) if drop_path > 0. else nn.Identity() |
||||
self.norm2 = norm_layer(dim) |
||||
mlp_hidden_dim = int(dim * mlp_ratio) |
||||
self.mlp = LAYERS.get_module('VanillaViTMLP')(in_features=dim, |
||||
hidden_features=mlp_hidden_dim, |
||||
act_layer=act_layer, |
||||
drop=drop) |
||||
|
||||
def forward(self, x): |
||||
x = x + self.drop_path(self.attn(self.norm1(x))) |
||||
x = x + self.drop_path(self.mlp(self.norm2(x))) |
||||
return x |
||||
|
||||
|
||||
@LAYERS.register_module |
||||
class VanillaViTHead(nn.Module): |
||||
"""Output layer of vanilla Vision Transformer |
||||
|
||||
:param in_features: size of input tensor |
||||
:type in_features: int |
||||
:param intermediate_features: hidden size |
||||
:type intermediate_features: int |
||||
:param out_features: size of output tensor |
||||
:type out_features: int |
||||
:param bias: whether to add bias, defaults to True |
||||
:type bias: bool, optional |
||||
""" |
||||
|
||||
def __init__(self, |
||||
in_features, |
||||
intermediate_features, |
||||
out_features, |
||||
bias=True |
||||
): |
||||
super().__init__() |
||||
self.linear_1 = nn.Linear( |
||||
in_features, intermediate_features, bias=bias) |
||||
self.act = nn.Tanh() |
||||
self.linear_2 = nn.Linear( |
||||
intermediate_features, out_features, bias=bias) |
||||
|
||||
def forward(self, x): |
||||
x = x[:, 0, :].squeeze(1) |
||||
x = self.linear_1(x) |
||||
x = self.act(x) |
||||
x = self.linear_2(x) |
||||
return x |
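A shape walkthrough for the patch-embedding defaults above (224x224 image, 16x16 patches), using a bare `nn.Conv2d` so the snippet runs without the registry:

```python
import torch
import torch.nn as nn

img_size, patch_size, in_chans, embed_dim = 224, 16, 3, 768
proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

x = torch.randn(2, in_chans, img_size, img_size)
tokens = proj(x).flatten(2).transpose(1, 2)   # BCHW -> BNC with N = (224 // 16) ** 2 = 196
print(tokens.shape)                           # torch.Size([2, 196, 768])
# VanillaViTPatchEmbedding then prepends a learnable class token and adds positional
# embeddings, so downstream blocks see 197 tokens per image.
```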
@ -0,0 +1,3 @@
|
||||
from .lambda_wrapper import LambdaWrapper |
||||
|
||||
__all__ = ['LambdaWrapper'] |
@ -0,0 +1,37 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
import torch.nn as nn |
||||
|
||||
from colossalai.builder import build_layer |
||||
from colossalai.registry import LAYERS |
||||
|
||||
|
||||
@LAYERS.register_module |
||||
class LambdaWrapper(nn.Module): |
||||
"""Wrap a function to nn.Module, which takes a config of layers and can fully access them |
||||
|
||||
:param func: user customed function |
||||
:type func: Callable |
||||
:param layers_cfg: config of layers, defaults to None |
||||
:type layers_cfg: dict, optional |
||||
""" |
||||
|
||||
def __init__(self, func, layers_cfg: dict = None): |
||||
super().__init__() |
||||
self.func = func |
||||
self.layers = self._build_layers(layers_cfg) |
||||
|
||||
def _build_layers(self, layers_cfg: dict): |
||||
if layers_cfg is None: |
||||
return None |
||||
else: |
||||
layers = [] |
||||
|
||||
for cfg in layers_cfg: |
||||
layer = build_layer(cfg) |
||||
layers.append(layer) |
||||
return layers |
||||
|
||||
def forward(self, *args, **kwargs): |
||||
return self.func(self, *args, **kwargs) |
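A minimal usage sketch, assuming the `LambdaWrapper` class above is in scope. Because `forward` calls `self.func(self, *args, **kwargs)`, the wrapped function always receives the wrapper module as its first argument; here we ignore it and wrap a plain elementwise function without any `layers_cfg`.

```python
import torch

# Hypothetical example: wrap an elementwise doubling function.
double = LambdaWrapper(lambda module, x: x * 2)
print(double(torch.ones(3)))   # tensor([2., 2., 2.])
```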
@ -0,0 +1,6 @@
|
||||
from .base_loss import BaseLoss |
||||
from .cross_entropy_2d import CrossEntropyLoss2D |
||||
from .cross_entropy_2p5d import CrossEntropyLoss2p5D |
||||
from .cross_entropy_3d import CrossEntropyLoss3D |
||||
|
||||
__all__ = ['CrossEntropyLoss2D', 'CrossEntropyLoss2p5D', 'CrossEntropyLoss3D'] |
@ -0,0 +1,13 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
from abc import ABC, abstractmethod |
||||
|
||||
|
||||
class BaseLoss(ABC): |
||||
"""Absctract loss class |
||||
""" |
||||
|
||||
@abstractmethod |
||||
def calc_loss(self, *args, **kwargs): |
||||
pass |
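A hypothetical subclass, only to show the contract the abstract method defines (it assumes `BaseLoss` above is in scope):

```python
import torch

class MeanAbsoluteError(BaseLoss):   # hypothetical example subclass
    def calc_loss(self, pred, target):
        return (pred - target).abs().mean()

print(MeanAbsoluteError().calc_loss(torch.tensor([1.0, 2.0]), torch.tensor([0.0, 0.0])))  # tensor(1.5000)
```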
@ -0,0 +1,120 @@
|
||||
#!/usr/bin/env python |
||||
# -*- encoding: utf-8 -*- |
||||
|
||||
import torch |
||||
import torch.nn.functional as F |
||||
from torch.nn.modules.loss import _Loss |
||||
|
||||
from colossalai.context.parallel_mode import ParallelMode |
||||
from colossalai.core import global_context as gpc |
||||
from colossalai.nn.layer.parallel_1d._utils import vocab_range_from_per_partition_vocab_size |
||||
|
||||
|
||||
class _VocabParallelCrossEntropy_1D(torch.autograd.Function): |
||||
|
||||
@staticmethod |
||||
def forward(ctx, vocab_parallel_logits, target): |
||||
# Maximum value along vocab dimension across all GPUs. |
||||
logits_max = torch.max(vocab_parallel_logits, dim=-1)[0] |
||||
torch.distributed.all_reduce(logits_max, |
||||
op=torch.distributed.ReduceOp.MAX, |
||||
group=gpc.get_group(ParallelMode.PARALLEL_1D)) |
||||
# Subtract the maximum value. |
||||
vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1)) |
||||
|
||||
# Get the partition's vocab indices
||||
partition_vocab_size = vocab_parallel_logits.size()[-1] |
||||
rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) |
||||
world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D) |
||||
vocab_start_index, vocab_end_index = vocab_range_from_per_partition_vocab_size( |
||||
partition_vocab_size, rank, world_size) |
||||
|
||||
# Create a mask of valid vocab ids (1 means it needs to be masked). |
||||
target_mask = (target < vocab_start_index) | (target >= vocab_end_index) |
||||
masked_target = target.clone() - vocab_start_index |
||||
masked_target[target_mask] = 0 |
||||
|
||||
# Get predicted-logits = logits[target]. |
||||
# For Simplicity, we convert logits to a 2-D tensor with size |
||||
# [*, partition-vocab-size] and target to a 1-D tensor of size [*]. |
||||
logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size) |
||||
masked_target_1d = masked_target.view(-1) |
||||
arange_1d = torch.arange(start=0, end=logits_2d.size()[0], |
||||
device=logits_2d.device) |
||||
predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] |
||||
predicted_logits_1d = predicted_logits_1d.clone().contiguous() |
||||
predicted_logits = predicted_logits_1d.view_as(target) |
||||
predicted_logits[target_mask] = 0.0 |
||||
# All reduce is needed to get the chunks from other GPUs. |
||||
torch.distributed.all_reduce(predicted_logits, |
||||
op=torch.distributed.ReduceOp.SUM, |
||||
group=gpc.get_group(ParallelMode.PARALLEL_1D)) |
||||
|
||||
# Sum of exponential of logits along vocab dimension across all GPUs. |
||||
exp_logits = vocab_parallel_logits |
||||
torch.exp(vocab_parallel_logits, out=exp_logits) |
||||
sum_exp_logits = exp_logits.sum(dim=-1) |
||||
torch.distributed.all_reduce(sum_exp_logits, |
||||
op=torch.distributed.ReduceOp.SUM, |
||||
group=gpc.get_group(ParallelMode.PARALLEL_1D)) |
||||
|
||||
# Loss = log(sum(exp(logits))) - predicted-logit. |
||||
loss = torch.log(sum_exp_logits) - predicted_logits |
||||
|
||||
# Store softmax, target-mask and masked-target for backward pass. |
||||
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) |
||||
ctx.save_for_backward(exp_logits, target_mask, masked_target_1d) |
||||
|
||||
return loss |
||||
|
||||
@staticmethod |
||||
def backward(ctx, grad_output): |
||||
# Retrieve tensors from the forward path.
||||
softmax, target_mask, masked_target_1d = ctx.saved_tensors |
||||
|
||||
# All the inputs have softmax as their gradient.
||||
grad_input = softmax |
||||
# For simplicity, work with the 2D gradient. |
||||
partition_vocab_size = softmax.size()[-1] |
||||
grad_2d = grad_input.view(-1, partition_vocab_size) |
||||
|
||||
# Add the gradient from matching classes. |
||||
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], |
||||
device=grad_2d.device) |
||||
grad_2d[arange_1d, masked_target_1d] -= ( |
||||
1.0 - target_mask.view(-1).float()) |
||||
|
||||
# Finally elementwise multiplication with the output gradients. |
||||
grad_input.mul_(grad_output.unsqueeze(dim=-1)) |
||||
|
||||
return grad_input, None |
||||
|
||||
|
||||
class LmLoss1D(_Loss): |
||||
|
||||
def forward(self, lm_logits, lm_labels, loss_mask): |
||||
lm_loss = _VocabParallelCrossEntropy_1D.apply(lm_logits, lm_labels) |
||||
lm_loss = torch.sum( |
||||
lm_loss.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum() |
||||
return lm_loss |
||||
|
||||
|
||||
class SopLoss1D(_Loss): |
||||
|
||||
def forward(self, sop_logits, sentence_order): |
||||
sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(), |
||||
sentence_order.view(-1), |
||||
ignore_index=-1) |
||||
return sop_loss |
||||
|
||||
|
||||
class BERTDualHeadLoss(_Loss): |
||||
|
||||
def __init__(self): |
||||
super().__init__()
self.lm_loss = LmLoss1D()
||||
self.sop_loss = SopLoss1D() |
||||
|
||||
def forward(self, lm_logits, sop_logits, lm_labels, loss_mask, sentence_order): |
||||
lm_loss = self.lm_loss(lm_logits, lm_labels, loss_mask) |
||||
sop_loss = self.sop_loss(sop_logits, sentence_order) |
||||
return lm_loss + sop_loss |
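Setting the distributed pieces aside, the per-token quantity `_VocabParallelCrossEntropy_1D` computes is the usual `log(sum(exp(logits))) - logits[target]`. A single-process sanity check of that identity against `F.cross_entropy` (toy shapes, no parallelism):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(5, 11)               # [tokens, vocab]
target = torch.randint(0, 11, (5,))

# log-sum-exp minus the logit of the target class, the same formula used above.
manual = torch.logsumexp(logits, dim=-1) - logits[torch.arange(5), target]
reference = F.cross_entropy(logits, target, reduction='none')
print(torch.allclose(manual, reference))  # True
```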
@ -0,0 +1,128 @@
|
||||
import torch |
||||
import torch.distributed as dist |
||||
from torch.nn.modules.loss import _Loss |
||||
|
||||
from colossalai.context.parallel_mode import ParallelMode |
||||
from colossalai.core import global_context as gpc |
||||
from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization, get_summa_dim_from_env |
||||
from colossalai.registry import LOSSES |
||||
from colossalai.utils import get_current_device |
||||
|
||||
|
||||
class _ParallelCrossEntropyLossFunction_2D(torch.autograd.Function): |
||||
### Adapted from megatron.mpu.cross_entropy ###
||||
|
||||
@staticmethod |
||||
def forward(ctx, logits, targets): |
||||
# logits: [b/q, h/q] |
||||
# labels: [b/q] |
||||
# loss: [b/q] |
||||
# vocab_parallel_logits: [b/q, s, v/q] |
||||
# target: [b/q, s] |
||||
logits_max = torch.max(logits, dim=-1)[0] |
||||
torch.distributed.all_reduce( |
||||
logits_max, |
||||
op=torch.distributed.ReduceOp.MAX, |
||||
group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW)) |
||||
# Subtract the maximum value. |
||||
# vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1)) |
||||
logits = logits - logits_max.unsqueeze(dim=-1) |
||||
|
||||
vocab_size = logits.size(-1) |
||||
rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) |
||||
vocab_start = rank * (vocab_size) |
||||
vocab_end = (rank + 1) * (vocab_size) - 1 |
||||
|
||||
target_mask = (targets < vocab_start) | (targets > vocab_end) |
||||
|
||||
masked_target = targets.clone() - vocab_start |
||||
masked_target[target_mask] = 0 |
||||
arange_1d = torch.arange(start=0, end=logits.size()[0], device=get_current_device())
||||
predicted_logits = logits[arange_1d, masked_target] |
||||
predicted_logits[target_mask] = 0. |
||||
dist.all_reduce(predicted_logits, group=gpc.get_group( |
||||
ParallelMode.PARALLEL_2D_ROW)) |
||||
|
||||
exp_logits = torch.exp(logits) |
||||
sum_exp_logits = exp_logits.sum(dim=1) |
||||
dist.all_reduce(sum_exp_logits, group=gpc.get_group( |
||||
ParallelMode.PARALLEL_2D_ROW)) |
||||
|
||||
loss = torch.log(sum_exp_logits) - predicted_logits |
||||
|
||||
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) |
||||
ctx.save_for_backward(exp_logits, target_mask, masked_target) |
||||
|
||||
return loss |
||||
|
||||
@staticmethod |
||||
def backward(ctx, output_grad): |
||||
# Retrieve tensors from the forward path.
||||
softmax, target_mask, masked_target = ctx.saved_tensors |
||||
|
||||
# All the inputs have softmax as their gradient. |
||||
grad_input = softmax |
||||
|
||||
# For simplicity, work with the 2D gradient. |
||||
partition_vocab_size = softmax.size()[-1] |
||||
grad_2d = grad_input.view(-1, partition_vocab_size) |
||||
|
||||
# Add the gradient from matching classes. |
||||
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], |
||||
device=get_current_device()) |
||||
grad_2d[arange_1d, |
||||
masked_target] -= (1.0 - target_mask.view(-1).float()) |
||||
|
||||
# Finally elementwise multiplication with the output gradients. |
||||
grad_input.mul_(output_grad.unsqueeze(dim=-1)) |
||||
|
||||
return grad_input, None |
||||
|
||||
|
||||
class _ReduceByColumn(torch.autograd.Function): |
||||
"""All-reduce the input from the model parallel region.""" |
||||
|
||||
@staticmethod |
||||
def symbolic(graph, input_): |
||||
dist.all_reduce(input_, group=gpc.get_group( |
||||
ParallelMode.PARALLEL_2D_COL)) |
||||
return input_ |
||||
|
||||
@staticmethod |
||||
def forward(ctx, input_): |
||||
dist.all_reduce(input_, group=gpc.get_group( |
||||
ParallelMode.PARALLEL_2D_COL)) |
||||
return input_ |
||||
|
||||
@staticmethod |
||||
def backward(ctx, grad_output): |
||||
return grad_output |
||||
|
||||
|
||||
@LOSSES.register_module |
||||
class CrossEntropyLoss2D(_Loss): |
||||
"""Cross entropy loss for 2D parallelism |
||||
|
||||
:param reduction: whether to average the loss, defaults to True |
||||
:type reduction: bool, optional |
||||
""" |
||||
|
||||
def __init__(self, reduction=True): |
||||
super().__init__() |
||||
assert_summa_initialization() |
||||
self.summa_dim = get_summa_dim_from_env() |
||||
self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) |
||||
self.reduction_mean = reduction |
||||
|
||||
def forward(self, logits, targets): |
||||
targets = targets.chunk(self.summa_dim, dim=0)[self.row_rank] |
||||
loss = _ParallelCrossEntropyLossFunction_2D.apply( |
||||
logits, targets, |
||||
) |
||||
if self.reduction_mean: |
||||
loss = _ReduceByColumn.apply(loss) / self.summa_dim |
||||
dist_loss = loss.mean() |
||||
|
||||
return dist_loss |