Support TP-compatible Torch AMP and Update trainer API (#27)

* Add gradient accumulation, fix lr scheduler

* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)

* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes

* fixed trainer

* Revert "fixed trainer"

This reverts commit 2e0b0b7699.

* improved consistency between trainer, engine and schedule (#23)

Co-authored-by: 1SAA <c2h214748@gmail.com>

Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
pull/28/head
Frank Lee 2021-11-18 19:45:06 +08:00 committed by GitHub
parent 2b05de4c64
commit 3defa32aee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
80 changed files with 2194 additions and 1584 deletions

View File

@ -42,26 +42,18 @@ pip install -v --no-cache-dir --global-option="--cuda_ext" .
```python
import colossalai
from colossalai.engine import Engine
from colossalai.trainer import Trainer
from colossalai.core import global_context as gpc
model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize()
engine = Engine(
model=model,
criterion=criterion,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
schedule=schedule
)
engine, train_dataloader, test_dataloader = colossalai.initialize()
trainer = Trainer(engine=engine,
hooks_cfg=gpc.config.hooks,
verbose=True)
trainer.fit(
train_dataloader=train_dataloader,
test_dataloader=test_dataloader,
max_epochs=gpc.config.num_epochs,
epochs=gpc.config.num_epochs,
hooks_cfg=gpc.config.hooks,
display_progress=True,
test_interval=5
)

View File

@ -1,2 +1,10 @@
from .builder import *
from .builder import (build_schedule, build_lr_scheduler, build_model, build_optimizer, build_optimizer_wrapper,
build_layer, build_loss, build_hooks, build_dataset, build_transform, build_data_sampler,
build_gradient_handler)
from .pipeline import ModelInitializer
__all__ = [
'build_schedule', 'build_lr_scheduler', 'build_model', 'build_optimizer', 'build_optimizer_wrapper',
'build_layer', 'build_loss', 'build_hooks', 'build_dataset', 'build_transform', 'build_data_sampler',
'build_gradient_handler', 'ModelInitializer'
]

View File

@ -181,18 +181,6 @@ def build_transform(config):
return build_from_registry(config, TRANSFORMS)
def build_pipe_alloc_policy(config):
"""Returns a pipeline allocation policy object constructed from `config`.
:param config: A python dict or a :class:`colossalai.context.Config` object
containing information used in the construction of the return object
:type config: dict or :class:`colossalai.context.Config`
:return: A pipeline allocation policy object
:rtype:
"""
return build_from_registry(config, PIPE_ALLOC_POLICY)
def build_data_sampler(config, dataset):
"""Returns a data sampler object of :class:`colossalai.nn.data.sampler.BaseSampler`
constructed from `config`.
@ -235,7 +223,7 @@ def build_optimizer_wrapper(config, optimizer, model=None):
return OPTIMIZER_WRAPPERS.get_module(mod_type)(optimizer, **config_)
def build_lr_scheduler(config, optimizer, total_steps, num_steps_per_epoch):
def build_lr_scheduler(config, optimizer):
"""Returns a learning rate scheduler object of :class:`torch.optim.lr_scheduler`
constructed from `config`, `optimizer`, `total_steps` and `num_steps_per_epoch`.
@ -254,9 +242,16 @@ def build_lr_scheduler(config, optimizer, total_steps, num_steps_per_epoch):
"""
config_ = config.copy()
mod_type = config_.pop('type')
# warmup epochs will overwrite warmup steps
if 'warmup_epochs' in config_:
warmup_epochs = config_.pop('warmup_epochs')
config_['warmup_steps'] = int(num_steps_per_epoch * warmup_epochs)
return LR_SCHEDULERS.get_module(mod_type)(optimizer, total_steps, num_steps_per_epoch=num_steps_per_epoch,
**config_)
return LR_SCHEDULERS.get_module(mod_type)(optimizer, **config_)
def build_schedule(config):
"""Returns a schedule of :class:`colossalai.engine.schedule.BaseSchedule`.
:param config: A python dict or a :class:`colossalai.context.Config` object
containing information used in the construction of the return object
:type config: dict or :class:`colossalai.context.Config`
:return: An object of :class:`colossalai.engine.schedule.BaseSchedule`
:rtype: :class:`colossalai.engine.schedule.BaseSchedule`
"""
return build_from_registry(config, SCHEDULE)

View File

@ -1,7 +1,7 @@
from .amp_type import AMP_TYPE
from ._base_engine import Engine
from .gradient_handler import *
from .schedule import *
from .amp import *
__all__ = ['Engine']

View File

@ -1,7 +1,9 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Optional
from torch.nn import Module
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from colossalai.builder import build_gradient_handler
from colossalai.context import ParallelMode
@ -9,89 +11,103 @@ from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger
from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)
from torch.nn import Module
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from .schedule import BaseSchedule, NoPipelineSchedule
from .schedule import BaseSchedule
class Engine:
"""Basic engine class for training and evaluation. It runs a specific process method
:meth:`step` which is based on the given :attr:`schedule` over each batch of a dataset.
It controls a iteration in training.
:param train_dataloader: Dataloader in training
:param test_dataloader: Dataloader in evaluation
:param model: The neural network model
:param criterion: Criterion for calculating loss
:param optimizer: Optimizer for updating the parameters
:param lr_scheduler: Learning rate scheduler ajusting learning rate during the training or evaluation
:param schedule: Running schedule in :meth:`step`
:type train_dataloader: DataLoader, optional
:type test_dataloader: DataLoader, optional
:param step_schedule: Running schedule in :meth:`step`
:param gradient_accumulation: Steps of gradient accumulation
:param gradient_clipping: The norm of gradient clipping
:type model: Module
:type criterion: _Loss, optional
:type optimizer: Optimizer, optional
:type lr_scheduler: _LRScheduler, optional
:type schedule: BaseSchedule, optional
:type optimizer: Optimizer
:type step_schedule: BaseSchedule, optional
:type gradient_accumulation: int, optional
:type gradient_clipping: float, optional
"""
def __init__(self,
train_dataloader: Optional[DataLoader] = None,
test_dataloader: Optional[DataLoader] = None,
model: Module = None,
criterion: _Loss = None,
optimizer: Optimizer = None,
lr_scheduler: Optional[_LRScheduler] = None,
schedule: BaseSchedule = None):
self.train_dataloader = train_dataloader
self.test_dataloader = test_dataloader
assert model is not None, "Engine requires a model"
self.model = model
self.criterion = criterion
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.schedule = schedule if schedule is not None \
else NoPipelineSchedule()
model: Module,
optimizer: Optimizer,
criterion: _Loss,
step_schedule: BaseSchedule,
gradient_handlers: list = None,
gradient_accumulation: int = 1,
gradient_clipping: float = 0.0,
):
self._model = model
self._optimizer = optimizer
self._criterion = criterion
self._schedule = step_schedule
# schedule initialize
self._schedule.initialize(model, optimizer)
# state
self.training = True # default
# gradient accumulation
assert gradient_accumulation > 0, 'gradient accumulation size must be larger than 0'
self._grad_accum_size = gradient_accumulation
self._grad_clip = gradient_clipping
self._logger = get_global_dist_logger()
# build gradient handler
self._gradient_handlers = []
gradient_handler_cfg = []
if hasattr(gpc.config, 'gradient_handler'):
assert isinstance(gpc.config.gradient_handler, list), \
if gradient_handlers is not None:
assert isinstance(gradient_handlers, list), \
f'argument gradient_handler_cfg expected type list, ' \
f'but got type {type(gpc.config.gradient_handler)}'
gradient_handler_cfg = gpc.config.gradient_handler
elif isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)):
gradient_handler_cfg = [dict(type='ZeROGradientHandler')]
f'but got type {type(gradient_handlers)}'
elif isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)):
gradient_handlers = [dict(type='ZeROGradientHandler')]
self._logger.info(
"Training with zero is detected, ZeROGradientHandler is automatically "
"added even though not specified in the configuration",
ranks=[0])
elif gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(
ParallelMode.DATA) > 1:
gradient_handler_cfg = [dict(type='DataParallelGradientHandler')]
gradient_handlers = [dict(type='DataParallelGradientHandler')]
self._logger.info(
"Data parallel training is detected, DataParallelGradientHandler is automatically "
"added even though not specified in the configuration",
ranks=[0])
if len(gradient_handler_cfg) == 0:
if gradient_handlers is None:
self._logger.warning(
"No gradient handler is set up, please make sure you do not need "
"to all-reduce the gradients after a training step.",
ranks=[0])
for cfg in gradient_handler_cfg:
handler = build_gradient_handler(cfg, self.model, self.optimizer)
self._gradient_handlers.append(handler)
else:
for cfg in gradient_handlers:
handler = build_gradient_handler(cfg, model, optimizer)
self._gradient_handlers.append(handler)
self.schedule.initialize(self.train_dataloader, self.model,
self.criterion, self.optimizer,
self.lr_scheduler)
self.forward_only = False
@property
def model(self):
return self._model
@property
def optimizer(self):
return self._optimizer
@property
def criterion(self):
return self._criterion
@property
def schedule(self):
return self._schedule
@property
def gradient_accumulation(self):
return self._grad_accum_size
def handle_gradient(self):
"""Handles all-reduce operations of gradients across different parallel groups.
@ -99,72 +115,62 @@ class Engine:
for handler in self._gradient_handlers:
handler.handle_gradient()
def set_dataloader(self, data: DataLoader, train: bool = True):
"""Sets dataloader in training or evaluation.
:param data: Dataloader to be set
:param train: Set training dataloader if True, otherwise evaluation dataloader
:type data: DataLoader
:type train: bool
"""
if train:
self.train_dataloader = data
else:
self.test_dataloader = data
def get_model(self):
"""Returns the neural network model in the engine.
"""
return self.model
def get_optimizer(self):
"""Returns optimizier in the engine.
"""
return self.optimizer
def get_lr_scheduler(self):
"""Returns the learning rate scheduler in the engine.
"""
return self.lr_scheduler
def train(self):
"""Sets the model to training mode.
"""
self.forward_only = False
self.schedule.train(dataloader=self.train_dataloader, mode=True)
self.training = True
self._model.train()
def eval(self):
"""Sets the model to evaluation mode.
"""
self.forward_only = True
self.schedule.train(dataloader=self.test_dataloader, mode=False)
self.training = False
self._model.eval()
def is_train(self):
"""Returns True if it is in training, otherwise False.
"""
return not self.forward_only
def get_lr(self):
"""Gets current learning rate.
"""
return self.schedule.get_lr()
def step(self, return_loss=True):
def step(self,
data_iter,
is_last_iteration: bool = False,
return_loss=True):
"""A running step based on the schedule. Usually, it runs a training or
evaluation over a batch of dataset.
:param data_iter: Data iterator of the dataset
:param is_last_iteration: If True, this iteration is the last iteration in the epoch
:param return_loss: loss will be returned if True
:type return_loss: bool
:type data_iter: Iterator
:type is_last_iteration: bool, optional
:type return_loss: bool, optional
:return: (output, lablel, loss)
"""
self.schedule.zero_grad(forward_only=self.forward_only)
if self.training:
self._optimizer.zero_grad()
output, label, loss = self.schedule.forward_backward_step(
forward_only=self.forward_only, return_loss=return_loss)
# differentiate training and eval with grad accum
if self.training:
for i in range(self._grad_accum_size):
output, label, loss = self._schedule.forward_backward_step(
data_iter, self._model, self._criterion, self._optimizer,
forward_only=False,
grad_accum_size=self._grad_accum_size,
return_loss=return_loss)
if not self.forward_only:
# all reduce gradients
self.handle_gradient()
if i == self._grad_accum_size - 1:
# all reduce gradients
self.handle_gradient()
self._schedule.optimizer_step(self._model, self._optimizer, self._grad_clip)
else:
output, label, loss = self._schedule.forward_backward_step(
data_iter, self._model, self._criterion, self._optimizer,
forward_only=True,
grad_accum_size=1,
return_loss=return_loss)
self.schedule.step()
# consume the remaining dataset left out due to gradient accumulation
if is_last_iteration:
while True:
try:
_ = next(data_iter)
except StopIteration:
break
return output, label, loss

View File

@ -0,0 +1,2 @@
from .grad_scaler import GradScaler
from .amp_type import AMP_TYPE

View File

@ -0,0 +1,577 @@
# modified from https://github.com/pytorch/pytorch/blob/master/torch/cuda/amp/grad_scaler.p
import torch
from collections import defaultdict, abc
import warnings
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
from colossalai.context import ParallelMode
import torch.distributed as dist
from colossalai.core import global_context as gpc
class _MultiDeviceReplicator(object):
"""
Lazily serves copies of a tensor to requested devices. Copies are cached per-device.
"""
def __init__(self, master_tensor: torch.Tensor) -> None:
assert master_tensor.is_cuda or master_tensor.device.type == 'xla'
self.master = master_tensor
self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}
def get(self, device) -> torch.Tensor:
retval = self._per_device_tensors.get(device, None)
if retval is None:
retval = self.master.to(
device=device, non_blocking=True, copy=True)
self._per_device_tensors[device] = retval
return retval
# Defines default_factory for GradScaler's _per_optimizer_states defaultdict,
# as well as associated "enum" values. Prefers defining these at top level because
# - Lambdas can't be pickled, so we don't want to supply a lambda as the factory.
# - Defining READY, UNSCALED, STEPPED and _refresh_per_optimizer_state within GradScaler
# causes a circular reference, which we'd rather avoid.
class OptState(Enum):
READY = 0
UNSCALED = 1
STEPPED = 2
def _refresh_per_optimizer_state():
return {"stage": OptState.READY, "found_inf_per_device": {}}
class GradScaler(object):
_scale: Optional[torch.Tensor]
_grows_tracker: Optional[torch.Tensor]
_per_optimizer_states: Dict[int, Dict[str, Any]]
"""
An instance ``scaler`` of :class:`GradScaler` helps perform the steps of gradient scaling
conveniently.
* ``scaler.scale(loss)`` multiplies a given loss by ``scaler``'s current scale factor.
* ``scaler.step(optimizer)`` safely unscales gradients and calls ``optimizer.step()``.
* ``scaler.update()`` updates ``scaler``'s scale factor.
Example::
# Creates a GradScaler once at the beginning of training.
scaler = GradScaler()
for epoch in epochs:
for input, target in data:
optimizer.zero_grad()
output = model(input)
loss = loss_fn(output, target)
# Scales loss. Calls backward() on scaled loss to create scaled gradients.
scaler.scale(loss).backward()
# scaler.step() first unscales gradients of the optimizer's params.
# If gradients don't contain infs/NaNs, optimizer.step() is then called,
# otherwise, optimizer.step() is skipped.
scaler.step(optimizer)
# Updates the scale for next iteration.
scaler.update()
See the :ref:`Automatic Mixed Precision examples<amp-examples>` for usage
(along with autocasting) in more complex cases like gradient clipping, gradient accumulation, gradient penalty,
and multiple losses/optimizers.
``scaler`` dynamically estimates the scale factor each iteration. To minimize gradient underflow,
a large scale factor should be used. However, ``float16`` values can "overflow" (become inf or NaN) if
the scale factor is too large. Therefore, the optimal scale factor is the largest factor that can be used
without incurring inf or NaN gradient values.
``scaler`` approximates the optimal scale factor over time by checking the gradients for infs and NaNs during every
``scaler.step(optimizer)`` (or optional separate ``scaler.unscale_(optimizer)``, see :meth:`unscale_`).
* If infs/NaNs are found, ``scaler.step(optimizer)`` skips the underlying ``optimizer.step()`` (so the params
themselves remain uncorrupted) and ``update()`` multiplies the scale by ``backoff_factor``.
* If no infs/NaNs are found, ``scaler.step(optimizer)`` runs the underlying ``optimizer.step()`` as usual.
If ``growth_interval`` unskipped iterations occur consecutively, ``update()`` multiplies the scale by
``growth_factor``.
The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its
value calibrates. ``scaler.step`` will skip the underlying ``optimizer.step()`` for these
iterations. After that, step skipping should occur rarely (once every few hundred or thousand iterations).
Args:
init_scale (float, optional, default=2.**16): Initial scale factor.
growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during
:meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations.
backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during
:meth:`update` if inf/NaN gradients occur in an iteration.
growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients
that must occur for the scale to be multiplied by ``growth_factor``.
enabled (bool, optional, default=True): If ``False``, disables gradient scaling. :meth:`step` simply
invokes the underlying ``optimizer.step()``, and other methods become no-ops.
"""
def __init__(self,
init_scale=2.**16,
growth_factor=2.0,
backoff_factor=0.5,
growth_interval=2000,
enabled=True):
if enabled and not torch.cuda.is_available():
warnings.warn(
"torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.")
self._enabled = False
else:
self._enabled = enabled
if self._enabled:
assert growth_factor > 1.0, "The growth factor must be > 1.0."
assert backoff_factor < 1.0, "The backoff factor must be < 1.0."
self._init_scale = init_scale
# self._scale will be lazily initialized during the first call to scale()
self._scale = None
self._growth_factor = growth_factor
self._backoff_factor = backoff_factor
self._growth_interval = growth_interval
self._init_growth_tracker = 0
# self._growth_tracker will be lazily initialized during the first call to scale()
self._growth_tracker = None
self._per_optimizer_states = defaultdict(
_refresh_per_optimizer_state)
def _check_scale_growth_tracker(self, funcname) -> Tuple[torch.Tensor, torch.Tensor]:
fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration."
assert self._scale is not None, "Attempted {} but _scale is None. ".format(
funcname) + fix
assert self._growth_tracker is not None, "Attempted {} but _growth_tracker is None. ".format(
funcname) + fix
return (self._scale, self._growth_tracker)
def _lazy_init_scale_growth_tracker(self, dev):
assert self._growth_tracker is None, "_growth_tracker initialized before _scale"
self._scale = torch.full(
(1,), self._init_scale, dtype=torch.float32, device=dev)
self._growth_tracker = torch.full(
(1,), self._init_growth_tracker, dtype=torch.int32, device=dev)
def scale(self, outputs):
"""
Multiplies ('scales') a tensor or list of tensors by the scale factor.
Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned
unmodified.
Args:
outputs (Tensor or iterable of Tensors): Outputs to scale.
"""
if not self._enabled:
return outputs
# Short-circuit for the common case.
if isinstance(outputs, torch.Tensor):
assert outputs.is_cuda or outputs.device.type == 'xla'
if self._scale is None:
self._lazy_init_scale_growth_tracker(outputs.device)
assert self._scale is not None
return outputs * self._scale.to(device=outputs.device, non_blocking=True)
# Invoke the more complex machinery only if we're treating multiple outputs.
# holds a reference that can be overwritten by apply_scale
stash: List[_MultiDeviceReplicator] = []
def apply_scale(val):
if isinstance(val, torch.Tensor):
assert val.is_cuda or val.device.type == 'xla'
if len(stash) == 0:
if self._scale is None:
self._lazy_init_scale_growth_tracker(val.device)
assert self._scale is not None
stash.append(_MultiDeviceReplicator(self._scale))
return val * stash[0].get(val.device)
elif isinstance(val, abc.Iterable):
iterable = map(apply_scale, val)
if isinstance(val, list) or isinstance(val, tuple):
return type(val)(iterable)
else:
return iterable
else:
raise ValueError(
"outputs must be a Tensor or an iterable of Tensors")
return apply_scale(outputs)
def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16):
per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
per_device_found_inf = _MultiDeviceReplicator(found_inf)
# To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
# There could be hundreds of grads, so we'd like to iterate through them just once.
# However, we don't know their devices or dtypes in advance.
# https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
# Google says mypy struggles with defaultdicts type annotations.
per_device_and_dtype_grads = defaultdict(
lambda: defaultdict(list)) # type: ignore[var-annotated]
with torch.no_grad():
for group in optimizer.param_groups:
for param in group["params"]:
if param.grad is None:
continue
if (not allow_fp16) and param.grad.dtype == torch.float16:
raise ValueError(
"Attempting to unscale FP16 gradients.")
if param.grad.is_sparse:
# is_coalesced() == False means the sparse grad has values with duplicate indices.
# coalesce() deduplicates indices and adds all values that have the same index.
# For scaled fp16 values, there's a good chance coalescing will cause overflow,
# so we should check the coalesced _values().
if param.grad.dtype is torch.float16:
param.grad = param.grad.coalesce()
to_unscale = param.grad._values()
else:
to_unscale = param.grad
# TODO: is there a way to split by device and dtype without appending in the inner loop?
per_device_and_dtype_grads[to_unscale.device][to_unscale.dtype].append(
to_unscale)
for device, per_dtype_grads in per_device_and_dtype_grads.items():
for grads in per_dtype_grads.values():
torch._amp_foreach_non_finite_check_and_unscale_(grads,
per_device_found_inf.get(
device),
per_device_inv_scale.get(device))
# For tensor parallel paramters it should be all-reduced over tensor parallel process group
if gpc.is_initialized(ParallelMode.TENSOR) and gpc.get_world_size(ParallelMode.TENSOR) > 1:
for tensor in per_device_found_inf._per_device_tensors.values():
dist.all_reduce(tensor, op=dist.ReduceOp.MAX,
group=gpc.get_group(ParallelMode.TENSOR))
return per_device_found_inf._per_device_tensors
def unscale_(self, optimizer):
"""
Divides ("unscales") the optimizer's gradient tensors by the scale factor.
:meth:`unscale_` is optional, serving cases where you need to
:ref:`modify or inspect gradients<working-with-unscaled-gradients>`
between the backward pass(es) and :meth:`step`.
If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`.
Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::
...
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
scaler.step(optimizer)
scaler.update()
Args:
optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled.
.. note::
:meth:`unscale_` does not incur a CPU-GPU sync.
.. warning::
:meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
and only after all gradients for that optimizer's assigned parameters have been accumulated.
Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.
.. warning::
:meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
"""
if not self._enabled:
return
self._check_scale_growth_tracker("unscale_")
optimizer_state = self._per_optimizer_states[id(optimizer)]
if optimizer_state["stage"] is OptState.UNSCALED:
raise RuntimeError(
"unscale_() has already been called on this optimizer since the last update().")
elif optimizer_state["stage"] is OptState.STEPPED:
raise RuntimeError("unscale_() is being called after step().")
# FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
assert self._scale is not None
inv_scale = self._scale.double().reciprocal().float()
found_inf = torch.full(
(1,), 0.0, dtype=torch.float32, device=self._scale.device)
optimizer_state["found_inf_per_device"] = self._unscale_grads_(
optimizer, inv_scale, found_inf, False)
optimizer_state["stage"] = OptState.UNSCALED
def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs):
retval = None
if not sum(v.item() for v in optimizer_state["found_inf_per_device"].values()):
retval = optimizer.step(*args, **kwargs)
return retval
def step(self, optimizer, *args, **kwargs):
"""
:meth:`step` carries out the following two operations:
1. Internally invokes ``unscale_(optimizer)`` (unless :meth:`unscale_` was explicitly called for ``optimizer``
earlier in the iteration). As part of the :meth:`unscale_`, gradients are checked for infs/NaNs.
2. If no inf/NaN gradients are found, invokes ``optimizer.step()`` using the unscaled
gradients. Otherwise, ``optimizer.step()`` is skipped to avoid corrupting the params.
``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``.
Returns the return value of ``optimizer.step(*args, **kwargs)``.
Args:
optimizer (torch.optim.Optimizer): Optimizer that applies the gradients.
args: Any arguments.
kwargs: Any keyword arguments.
.. warning::
Closure use is not currently supported.
"""
if (not self._enabled):
return optimizer.step(*args, **kwargs)
if "closure" in kwargs:
raise RuntimeError(
"Closure use is not currently supported if GradScaler is enabled.")
self._check_scale_growth_tracker("step")
optimizer_state = self._per_optimizer_states[id(optimizer)]
if optimizer_state["stage"] is OptState.STEPPED:
raise RuntimeError(
"step() has already been called since the last update().")
retval = None
if (hasattr(optimizer, "_step_supports_amp_scaling") and optimizer._step_supports_amp_scaling):
# This optimizer has customized scale-handling logic, so we can call optimizer.step() directly.
# The contract with custom optimizers is that their step() should accept an additional,
# optional grad_scaler kwarg. We append self to the kwargs so the custom optimizer has full information:
# it can query its own state, invoke unscale_ on itself, etc
retval = optimizer.step(*args, **dict(kwargs, grad_scaler=self))
optimizer_state["stage"] = OptState.STEPPED
return retval
if optimizer_state["stage"] is OptState.READY:
self.unscale_(optimizer)
assert len(optimizer_state["found_inf_per_device"]
) > 0, "No inf checks were recorded for this optimizer."
retval = self._maybe_opt_step(
optimizer, optimizer_state, *args, **kwargs)
optimizer_state["stage"] = OptState.STEPPED
return retval
def update(self, new_scale=None):
"""
Updates the scale factor.
If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
the scale is multiplied by ``growth_factor`` to increase it.
Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
used directly, it's used to fill GradScaler's internal scale tensor. So if
``new_scale`` was a tensor, later in-place changes to that tensor will not further
affect the scale GradScaler uses internally.)
Args:
new_scale (float or :class:`torch.cuda.FloatTensor`, optional, default=None): New scale factor.
.. warning::
:meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
been invoked for all optimizers used this iteration.
"""
if not self._enabled:
return
_scale, _growth_tracker = self._check_scale_growth_tracker("update")
if new_scale is not None:
# Accept a new user-defined scale.
if isinstance(new_scale, float):
self._scale.fill_(new_scale) # type: ignore[union-attr]
else:
reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False."
# type: ignore[attr-defined]
assert isinstance(new_scale, torch.cuda.FloatTensor), reason
assert new_scale.numel() == 1, reason
assert new_scale.requires_grad is False, reason
self._scale.copy_(new_scale) # type: ignore[union-attr]
else:
# Consume shared inf/nan data collected from optimizers to update the scale.
# If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
found_infs = [found_inf.to(device=_scale.device, non_blocking=True)
for state in self._per_optimizer_states.values()
for found_inf in state["found_inf_per_device"].values()]
assert len(
found_infs) > 0, "No inf checks were recorded prior to update."
found_inf_combined = found_infs[0]
if len(found_infs) > 1:
for i in range(1, len(found_infs)):
found_inf_combined += found_infs[i]
torch._amp_update_scale_(_scale,
_growth_tracker,
found_inf_combined,
self._growth_factor,
self._backoff_factor,
self._growth_interval)
# To prepare for next iteration, clear the data collected from optimizers this iteration.
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
def _get_scale_async(self):
return self._scale
def get_scale(self):
"""
Returns a Python float containing the current scale, or 1.0 if scaling is disabled.
.. warning::
:meth:`get_scale` incurs a CPU-GPU sync.
"""
if self._enabled:
return self._init_scale if self._scale is None else self._get_scale_async().item()
else:
return 1.0
def get_growth_factor(self):
r"""
Returns a Python float containing the scale growth factor.
"""
return self._growth_factor
def set_growth_factor(self, new_factor):
r"""
Args:
new_scale (float): Value to use as the new scale growth factor.
"""
self._growth_factor = new_factor
def get_backoff_factor(self):
r"""
Returns a Python float containing the scale backoff factor.
"""
return self._backoff_factor
def set_backoff_factor(self, new_factor):
r"""
Args:
new_scale (float): Value to use as the new scale backoff factor.
"""
self._backoff_factor = new_factor
def get_growth_interval(self):
r"""
Returns a Python int containing the growth interval.
"""
return self._growth_interval
def set_growth_interval(self, new_interval):
r"""
Args:
new_interval (int): Value to use as the new growth interval.
"""
self._growth_interval = new_interval
def _get_growth_tracker(self):
if self._enabled:
return self._init_growth_tracker if self._growth_tracker is None else self._growth_tracker.item()
else:
return 0
def is_enabled(self):
r"""
Returns a bool indicating whether this instance is enabled.
"""
return self._enabled
def state_dict(self):
r"""
Returns the state of the scaler as a :class:`dict`. It contains five entries:
* ``"scale"`` - a Python float containing the current scale
* ``"growth_factor"`` - a Python float containing the current growth factor
* ``"backoff_factor"`` - a Python float containing the current backoff factor
* ``"growth_interval"`` - a Python int containing the current growth interval
* ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps.
If this instance is not enabled, returns an empty dict.
.. note::
If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict`
should be called after :meth:`update`.
"""
return {"scale": self.get_scale(),
"growth_factor": self._growth_factor,
"backoff_factor": self._backoff_factor,
"growth_interval": self._growth_interval,
"_growth_tracker": self._get_growth_tracker()} if self._enabled else {}
def load_state_dict(self, state_dict):
r"""
Loads the scaler state. If this instance is disabled, :meth:`load_state_dict` is a no-op.
Args:
state_dict(dict): scaler state. Should be an object returned from a call to :meth:`state_dict`.
"""
if not self._enabled:
return
if len(state_dict) == 0:
raise RuntimeError("The source state dict is empty, possibly because it was saved "
"from a disabled instance of GradScaler.")
self._init_scale = state_dict["scale"]
if self._scale is not None:
self._scale.fill_(state_dict["scale"])
self._growth_factor = state_dict["growth_factor"]
self._backoff_factor = state_dict["backoff_factor"]
self._growth_interval = state_dict["growth_interval"]
self._init_growth_tracker = state_dict["_growth_tracker"]
if self._growth_tracker is not None:
self._growth_tracker.fill_(state_dict["_growth_tracker"])
def __getstate__(self):
state = self.__dict__.copy()
if self._enabled:
assert len(self._per_optimizer_states) == 0, "A GradScaler instance may only be pickled at the beginning "\
"of an iteration, or at the end after scaler.update()."
# Pickling _scale and _growth_tracker Tensors directly triggers
# "warnings.warn("pickle support for Storage will be removed in 1.5..."
# so instead, we set the unpickled instance up to reinitialize them lazily.
state['_init_scale'] = self.get_scale()
state['_init_growth_tracker'] = self._get_growth_tracker()
state['_scale'] = None
state['_growth_tracker'] = None
return state
def __setstate__(self, state):
self.__dict__.update(state)
def _check_inf_per_device(self, optimizer):
_scale, _ = self._check_scale_growth_tracker("_check_inf_per_device")
dummy_inv_scale = torch.full(
(1,), 1.0, dtype=torch.float32, device=_scale.device)
found_inf = torch.full(
(1,), 0.0, dtype=torch.float32, device=_scale.device)
self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \
self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True)
return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]
def _found_inf_per_device(self, optimizer):
return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]

View File

@ -5,125 +5,85 @@ from abc import ABC, abstractmethod
import torch
from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger
from colossalai.utils import get_current_device
class BaseSchedule(ABC):
"""A basic helper class to control the process of training or evaluation.
It mainly composes of forward_backward_step for gradient backward and
optimizer_step for parameters update.
For the convenience to enable FP16, we aggreate all codes that contain the
control of FP16 in class schedule.
"""
def __init__(self):
self.initialized = False
self.logger = get_global_dist_logger()
@property
@abstractmethod
def num_steps(self):
"""The number of batches in training or evaluation.
"""
pass
def initialize(self,
dataloader=None,
model=None,
criterion=None,
optimizer=None,
lr_scheduler=None):
"""Initializes the schedule and set parameters before running.
:param dataloader: DataLoader in training or evaluation
:param model: The neural network model
:param criterion: Criterion for calculating loss
:param optimizer: Optimizer for updating the parameters
:param lr_scheduler: Learning rate scheduler in the process
"""
self.dataloader = dataloader
assert model is not None, "Schedule requires a model"
self.model = model
assert criterion is not None, "Schedule requires a criterion"
self.criterion = criterion
assert optimizer is not None, "Schedule requires an optimizer"
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.initialized = True
def check_initialized(self):
"""Checks whether the schedule is initialized.
"""
assert self.initialized, \
'Schedule is not initialized. Call schedule.initialize(...) before using it.'
def load_batch(self):
"""Loads a batch of dataset. It returns the data and labels which are
already in the same GPU as where the model's.
:return: (data, label)
:rtype: (Tensor, Tensor)
"""
self.check_initialized()
if self.data_iter is None:
raise RuntimeError('Dataloader is not defined.')
data, label = next(self.data_iter)
return self._move_to_device(data), self._move_to_device(label)
@staticmethod
def _move_tensor(element):
if torch.is_tensor(element):
if not element.is_cuda:
return element.to(get_current_device()).detach()
return element
def _move_to_device(self, data):
if isinstance(data, (
tuple,
list,
)):
data = tuple([
d.to(get_current_device()).detach() for d in data
if torch.is_tensor(d)
])
if isinstance(data, (tuple, list)):
data = tuple([self._move_tensor(d) for d in data])
elif torch.is_tensor(data):
data = data.to(get_current_device()).detach()
return data
def train(self, dataloader=None, mode=True):
"""Sets the dataloader to be used and turn the model to
training or evaluation mode.
def load_batch(self, data_iter):
"""Loads a batch from data iterator. It returns the data and labels which are
already in the same GPU as where the model's.
:param dataloader: Dataloader to be used
:param mode: If True, the model will set as training mode. Otherwise, evaluation mode.
:return: (data, label)
:rtype: (Tensor, Tensor)
"""
self.check_initialized()
if mode:
self.model.train()
else:
self.model.eval()
if dataloader is not None:
self.dataloader = dataloader
self.data_iter = iter(dataloader)
if data_iter is None:
raise RuntimeError('Dataloader is not defined.')
data, label = next(data_iter)
return self._move_to_device(data), self._move_to_device(label)
def zero_grad(self, forward_only=False):
"""Cleans gradients with the optimizer.
"""
if not forward_only:
self.check_initialized()
self.optimizer.zero_grad()
def initialize(self, model, optimizer):
"""Initializes the model and the optimizer before training.
This is often used in FP16 training.
def get_lr(self):
"""Returns the current learning rate.
:param model: The neural network model
:param optimizer: Optimizer for updating the parameters
"""
if self.lr_scheduler is not None:
return self.lr_scheduler.get_lr()[0]
else:
return self.optimizer.param_groups[0]['lr']
def step(self):
"""Updates the parameters and learning rate with the optimizer.
"""
self.check_initialized()
self.optimizer.step()
# update lr scheduler
if self.lr_scheduler is not None:
self.lr_scheduler.step()
return model, optimizer
@abstractmethod
def forward_backward_step(self, forward_only=False, return_loss=True):
def forward_backward_step(self,
data_iter,
model,
criterion,
optimizer=None,
forward_only=False,
grad_accum_size: int = 1,
return_loss=True):
"""The process function over a batch of dataset for training or evaluation.
:param forward_only: If True, the process won't include backward.
:param return_loss: If False, the loss won't be returned.
:param data_iter: Data iterator of the dataset
:param model: Model used in training or evaluation
:param optimizer: Optimizer used in training or evaluation
:param criterion: Loss function
:param forward_only: If True, the process won't include backward
:param grad_accum_size: Steps of gradient accumulation
:param return_loss: If False, the loss won't be returned
"""
pass
@abstractmethod
def optimizer_step(self, model, optimizer, grad_clipping: float = 0.0):
"""Updates the parameters with the optimizer.
:param model: The neural network model
:param optimizer: Optimizer for updating the parameters
:param grad_clipping: The norm of gradient clipping
:type grad_clipping: float, optional
"""
pass

View File

@ -4,19 +4,24 @@
try:
import apex.amp as apex_amp
except:
print('apex is required for mixed precision training')
pass
try:
import torch.cuda.amp as torch_amp
except:
print('PyTorch amp is not supported with the current PyTorch version')
pass
from typing import Iterable
import torch.nn as nn
from torch.optim import Optimizer
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.engine.amp_type import AMP_TYPE
from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)
from ._utils import convert_to_fp16
from colossalai.nn.optimizer._utils import clip_grad_norm_fp32
from ._base_schedule import BaseSchedule
from ._utils import convert_to_fp16, convert_to_fp32
from ..amp import AMP_TYPE, GradScaler
class NoPipelineSchedule(BaseSchedule):
@ -30,6 +35,7 @@ class NoPipelineSchedule(BaseSchedule):
:type amp_type: AMP_TYPE
:type amp_config: dict
"""
def __init__(
self,
amp_type: AMP_TYPE = None,
@ -41,12 +47,6 @@ class NoPipelineSchedule(BaseSchedule):
assert amp_type is None or isinstance(amp_type, AMP_TYPE), \
'unrecognised value for argument fp16, it can only be None, torch or apex'
# LSG: check compatibility
# LSG: torch.cuda.amp and apex.amp cannot be used for tensor parallel
if gpc.is_initialized(ParallelMode.TENSOR) and gpc.get_world_size(
ParallelMode.TENSOR) > 1:
assert amp_type != AMP_TYPE.TORCH and amp_type != AMP_TYPE.APEX, \
'You can only AMP_TYPE.PARALLEL for tensor parallel training'
self.use_zero_level_2_3 = False
if amp_type is not None:
@ -79,107 +79,110 @@ class NoPipelineSchedule(BaseSchedule):
self.fp16 = False
self.amp_type = None
@property
def num_steps(self):
return len(self.dataloader)
def initialize(self,
dataloader,
model,
criterion,
optimizer,
lr_scheduler=None):
super().initialize(dataloader,
model,
criterion,
optimizer,
lr_scheduler=lr_scheduler)
if isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)):
def initialize(self, model: nn.Module, optimizer: Optimizer):
if isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)):
self.use_zero_level_2_3 = True
assert self.amp_type != AMP_TYPE.PARALLEL, 'ZeRO Level 2 and 3 are mutually exclusive with AMP_TYPE.PARALLEL'
assert self.amp_type != AMP_TYPE.PARALLEL, \
'ZeRO Level 2 and 3 are mutually exclusive with AMP_TYPE.PARALLEL'
if self.fp16:
if self.amp_type == AMP_TYPE.TORCH:
self._torch_amp_scaler = torch_amp.GradScaler(**self.amp_cfg)
self._torch_amp_scaler = GradScaler(**self.amp_cfg)
elif self.amp_type == AMP_TYPE.APEX:
self.model, self.optimizer = apex_amp.initialize(
self.model, self.optimizer, **self.amp_cfg)
model, optimizer = apex_amp.initialize(model, optimizer, **self.amp_cfg)
def forward_backward_step(self, forward_only=False, return_loss=True):
return model, optimizer
def forward_backward_step(self,
data_iter: Iterable,
model: nn.Module,
criterion: nn.modules.loss._Loss,
optimizer: Optimizer = None,
forward_only: bool = False,
grad_accum_size: int = 1,
return_loss: bool = True):
"""The process function that loads loads a batch of dataset and feeds it to the model.
The returned labels and loss will None if :attr:`return_loss` is False.
:param data_iter: Data iterator of the dataloader, e.g. iter(dataloader)
:param model: Model for training and inference
:param criterion: Loss function for training
:param optimizer: Optimizer used for training
:param forward_only: If True, the model is run for the forward pass, else back propagation will be executed
:param grad_accum_size: The number of iterations for gradient accumulation
:param return_loss: Loss will be returned if True
:type data_iter: Iterator
:type model: torch.nn.Module
:type criterion: torch.nn.modules.loss._Loss
:type optimizer: torch.optim.Optimizer
:type forward_only: bool, optional
:type grad_accum_size: int
:type return_loss: bool, optional
:return: (output, label, loss)
"""
assert forward_only or return_loss, \
'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
data, label = self.load_batch()
data, label = self.load_batch(data_iter)
loss = None
# LSG: leave for debug, make sure dataloader is deterministic
# if forward_only:
# img = data[0]
# rank = gpc.get_local_rank(ParallelMode.DATA)
# world_size = gpc.get_world_size(ParallelMode.DATA)
# group = gpc.get_group(ParallelMode.DATA)
# input_list = [img.clone() for _ in range(world_size)]
# output_list = [torch.empty_like(img) for _ in range(world_size)]
# output_list[rank] = img.clone()
# dist.all_to_all(output_tensor_list=output_list, input_tensor_list=input_list, group=group)
# assert torch.equal(output_list[0], output_list[1]) # and torch.equal(output_list[1], output_list[2])
# forward
if self.fp16 and self.amp_type == AMP_TYPE.TORCH:
with torch_amp.autocast():
output = self.model(*data)
output = model(*data)
if not isinstance(output, (tuple, list)):
output = (output,)
if return_loss:
loss = self.criterion(*output, *label)
loss = criterion(*output, *label)
else:
if self.use_zero_level_2_3 or self.amp_type == AMP_TYPE.PARALLEL:
data = convert_to_fp16(data)
output = self.model(*data)
output = model(*data)
if self.use_zero_level_2_3 or self.amp_type == AMP_TYPE.PARALLEL:
output = convert_to_fp32(output)
if not isinstance(output, (tuple, list)):
output = (output,)
if return_loss:
loss = self.criterion(*output, *label)
loss = criterion(*output, *label)
loss /= grad_accum_size
if not forward_only:
# backward
if self.use_zero_level_2_3:
self.optimizer.backward(loss)
optimizer.backward(loss)
elif self.fp16:
if self.amp_type == AMP_TYPE.APEX:
with apex_amp.scale_loss(loss,
self.optimizer) as scaled_loss:
with apex_amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
elif self.amp_type == AMP_TYPE.TORCH:
self._torch_amp_scaler.scale(loss).backward()
elif self.amp_type == AMP_TYPE.PARALLEL:
loss = self.optimizer.scale_loss(loss)
loss = optimizer.scale_loss(loss)
loss.backward()
# scale back to display the original value in logs
loss.div_(self.optimizer.grad_scaler.scale)
loss.div_(optimizer.grad_scaler.scale)
else:
loss.backward()
if return_loss:
return output, label, loss
return output, label, loss * grad_accum_size
else:
return output, None, None
def step(self):
def optimizer_step(self, model: nn.Module, optimizer: Optimizer, grad_clipping: float = 0.0):
# step optimizer
if self.fp16 and self.amp_type == AMP_TYPE.TORCH:
self._torch_amp_scaler.step(self.optimizer)
if grad_clipping > 0.0:
self._torch_amp_scaler.unscale_(optimizer)
clip_grad_norm_fp32(model.parameters(), grad_clipping)
self._torch_amp_scaler.step(optimizer)
self._torch_amp_scaler.update()
else:
self.optimizer.step()
# update lr scheduler
if self.lr_scheduler is not None:
self.lr_scheduler.step()
if not self.fp16 and not self.use_zero_level_2_3 and grad_clipping > 0.0:
clip_grad_norm_fp32(model.parameters(), grad_clipping)
optimizer.step()

View File

@ -15,7 +15,7 @@ from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
from colossalai.utils import get_current_device
from ._base_schedule import BaseSchedule
from ._utils import convert_to_fp16
from ..amp_type import AMP_TYPE
from ..amp import AMP_TYPE
def squeeze(x: Union[Tensor, tuple, list]):
@ -93,12 +93,11 @@ class PipelineSchedule(BaseSchedule):
)
# Pipeline schedule just puts data in memory
def load_batch(self):
self.check_initialized()
if self.data_iter is None:
def load_batch(self, data_iter):
if data_iter is None:
raise RuntimeError('Dataloader is not defined.')
self.batch_pos = 0
data, label = next(self.data_iter)
data, label = next(data_iter)
self.batch_data, self.batch_label = \
self._move_to_device(data), self._move_to_device(label)
batch_size = self.batch_data.shape[0]
@ -117,23 +116,8 @@ class PipelineSchedule(BaseSchedule):
self.batch_pos += self.microbatch_size
return (data,), (label,)
@property
def num_steps(self):
return len(self.dataloader)
def initialize(self,
dataloader,
model,
criterion,
optimizer,
lr_scheduler=None):
super().initialize(dataloader,
model,
criterion,
optimizer,
lr_scheduler=lr_scheduler)
if isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)):
def initialize(self, model, optimizer):
if isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)):
raise TypeError(
"Pipeline schedule is currently not compatible with ZeRO Level 2 and Level 3"
)
@ -145,7 +129,8 @@ class PipelineSchedule(BaseSchedule):
'default tensor dtype is set to torch.half for fp16 training',
ranks=[0])
def forward_step(self, input_tensor, return_tensors, return_loss=True):
def forward_step(self, model, criterion, input_tensor, return_tensors,
grad_accum_size, return_loss=True):
"""Forward step for passed-in model. If it is the first stage, the input tensor
is obtained from data_iterator, otherwise the passed-in input_tensor is used.
Returns output tensor. This is a helper function and can be ignored by users.
@ -156,14 +141,14 @@ class PipelineSchedule(BaseSchedule):
if self.amp_type == AMP_TYPE.PARALLEL:
input_tensor = convert_to_fp16(input_tensor)
input_tensor = squeeze(input_tensor)
output_tensor = self.model(input_tensor)
output_tensor = model(input_tensor)
output_tensor = squeeze(output_tensor)
if gpc.is_last_rank(ParallelMode.PIPELINE):
if return_loss:
input_tensor, label = self.load_micro_batch()
loss_reduced = self.criterion(output_tensor, *
label) / self.num_microbatches
loss_reduced = criterion(output_tensor, *label) \
/ (self.num_microbatches * grad_accum_size)
return_tensors.append(
tuple((output_tensor, label[0], loss_reduced)))
return loss_reduced
@ -174,7 +159,7 @@ class PipelineSchedule(BaseSchedule):
else:
return output_tensor
def backward_step(self, input_tensor, output_tensor, output_tensor_grad):
def backward_step(self, optimizer, input_tensor, output_tensor, output_tensor_grad):
"""Backward step through the passed-in output tensor. If it is the last stage, the
output_tensor_grad is None, otherwise it is the gradients with respect to stage's output tensor.
Returns the gradients with respect to the input tensor (None if first stage).
@ -187,7 +172,7 @@ class PipelineSchedule(BaseSchedule):
# Backward pass.
if output_tensor_grad is None and self.amp_type == AMP_TYPE.PARALLEL:
output_tensor = self.optimizer.scale_loss(output_tensor)
output_tensor = optimizer.scale_loss(output_tensor)
torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
# Collect the grad of the input_tensor.
@ -197,17 +182,24 @@ class PipelineSchedule(BaseSchedule):
return input_tensor_grad
def forward_backward_step(self, forward_only=True, return_loss=True):
def forward_backward_step(self,
data_iter,
model,
criterion,
optimizer=None,
forward_only=False,
grad_accum_size: int = 1,
return_loss=True):
"""Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
Returns a tuple with losses if the last stage, an empty tuple otherwise.
:return: (output, label, loss)
"""
assert forward_only or return_loss, \
'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
self.load_batch()
self.load_batch(data_iter)
num_warmup_microbatches = \
(gpc.get_world_size(ParallelMode.PIPELINE) -
gpc.get_local_rank(ParallelMode.PIPELINE) - 1)
@ -233,9 +225,11 @@ class PipelineSchedule(BaseSchedule):
if not gpc.is_first_rank(ParallelMode.PIPELINE):
ft_shape = recv_tensor_meta(ft_shape)
input_tensor = recv_forward(ft_shape)
output_tensor = self.forward_step(input_tensor,
return_tensors,
return_loss=return_loss)
output_tensor = self.forward_step(
model, criterion,
input_tensor, return_tensors,
grad_accum_size, return_loss=return_loss
)
if not gpc.is_last_rank(ParallelMode.PIPELINE):
bt_shape = output_tensor.shape
fs_checker = send_tensor_meta(output_tensor, fs_checker)
@ -257,9 +251,11 @@ class PipelineSchedule(BaseSchedule):
for i in range(num_microbatches_remaining):
last_iteration = (i == (num_microbatches_remaining - 1))
output_tensor = self.forward_step(input_tensor,
return_tensors,
return_loss=return_loss)
output_tensor = self.forward_step(
model, criterion,
input_tensor, return_tensors,
grad_accum_size, return_loss=return_loss
)
if forward_only:
send_forward(output_tensor)
@ -279,9 +275,11 @@ class PipelineSchedule(BaseSchedule):
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
input_tensor_grad = self.backward_step(input_tensor,
output_tensor,
output_tensor_grad)
input_tensor_grad = self.backward_step(
optimizer,
input_tensor, output_tensor,
output_tensor_grad
)
if last_iteration:
input_tensor = None
@ -298,9 +296,11 @@ class PipelineSchedule(BaseSchedule):
output_tensor_grad = recv_backward(bt_shape)
input_tensor_grad = self.backward_step(input_tensor,
output_tensor,
output_tensor_grad)
input_tensor_grad = self.backward_step(
optimizer,
input_tensor, output_tensor,
output_tensor_grad
)
send_backward(input_tensor_grad)
@ -309,8 +309,11 @@ class PipelineSchedule(BaseSchedule):
output, label, loss = tuple(map(list, zip(*return_tensors)))
return (torch.cat(output, dim=0),
torch.cat(label, dim=0),
sum(loss))
sum(loss) * grad_accum_size)
else:
return tuple((torch.cat(return_tensors, dim=0), None, None))
else:
return tuple((None, None, None))
def optimizer_step(self, model, optimizer, grad_clipping: float = 0.0):
optimizer.step()

View File

@ -14,3 +14,14 @@ def convert_to_fp16(data: Union[Tensor, List[Tensor]]):
else:
raise TypeError(f"Expected argument 'data' to be a Tensor or a list/tuple of Tensor, but got {type(data)}")
return ret
def convert_to_fp32(data: Union[Tensor, List[Tensor]]):
if isinstance(data, Tensor):
ret = data.float()
elif isinstance(data, (list, tuple)):
ret = [val.float() for val in data]
else:
raise TypeError(f"Expected argument 'data' to be a Tensor or a list/tuple of Tensor, but got {type(data)}")
return ret

View File

@ -6,18 +6,20 @@ import pprint
import random
from pathlib import Path
from typing import Callable, Iterable, Optional, Union
from typing import Tuple
import numpy as np
import torch
from torch.utils.data import DataLoader
from colossalai.engine import AMP_TYPE, NoPipelineSchedule, PipelineSchedule
from colossalai.engine import Engine
from colossalai.logging import get_global_dist_logger, init_global_dist_logger
from colossalai.nn import DataParallelSampler
from colossalai.nn.model.base_model import BaseModel
from .builder import (ModelInitializer, build_dataset, build_loss,
build_lr_scheduler, build_model, build_optimizer,
build_optimizer_wrapper)
build_model, build_optimizer,
build_optimizer_wrapper, build_schedule)
from .context import Config, ParallelMode
from .core import global_context as gpc
from .utils import get_current_device, sync_model_param_in_dp
@ -182,7 +184,7 @@ def initialize(config: Union[str, dict] = None,
backend: str = None,
train_dataloader: Optional[Union[Iterable, Callable]] = None,
test_dataloader: Optional[Union[Iterable, Callable]] = None,
):
) -> Tuple[Engine, DataLoader, DataLoader]:
'''Core function that initializes distributed environment, logger, cudnn, data, model, loss function, optimizer, and lr_scheduler(their configs are in gpc.config).
:param config: config file or config file path are both acceptable
@ -201,7 +203,7 @@ def initialize(config: Union[str, dict] = None,
:type train_dataloader: Optional[Union[Iterable, Callable]], optional
:param test_dataloader: If None, the config is used to build a dataloder; Else, it should be a dataloader object or a function with no arguments which can build a dataloader, defaults to None
:type test_dataloader: Optional[Union[Iterable, Callable]], optional
:return: (model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler)
:return: (engine, train_dataloader, test_dataloader, criterion)
:rtype: tuple
'''
# initialize distributed environment
@ -337,21 +339,7 @@ def initialize(config: Union[str, dict] = None,
optimizer = build_optimizer_wrapper(fp16_cfg, optimizer)
logger.info('Optimizer is created', ranks=[0])
lr_scheduler = None
if hasattr(gpc.config, 'lr_scheduler'):
if hasattr(gpc.config, 'num_steps'):
total_steps = gpc.config.num_steps
elif hasattr(gpc.config, 'num_epochs'):
total_steps = int(gpc.config.num_epochs * len(train_dataloader))
else:
raise Exception(
'Please specify training stopping criterion num_steps or num_epochs in your configuration.'
)
lr_scheduler = build_lr_scheduler(gpc.config.lr_scheduler, optimizer,
total_steps, len(train_dataloader))
logger.info('Learning rate scheduler is created', ranks=[0])
# pipeline or no pipeline schedule
# build schedule and engine
if hasattr(gpc.config, 'fp16'):
amp_type = gpc.config.fp16.mode
amp_cfg = gpc.config.fp16.copy()
@ -360,12 +348,32 @@ def initialize(config: Union[str, dict] = None,
amp_type = None
amp_cfg = None
if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
assert hasattr(gpc.config,
'schedule'), "Config 'schedule' not found in your configuration file for pipeline parallel training"
engine_cfg = gpc.config.get('engine', dict())
schedule_cfg = engine_cfg.pop('schedule', None)
schedule_type = None
if schedule_cfg is not None:
schedule_type = schedule_cfg.get('type', None)
if schedule_type is not None:
# run customized schedule
schedule_cfg['amp_type'] = amp_type
schedule_cfg['amp_config'] = amp_cfg
schedule = build_schedule(schedule_cfg)
elif gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
assert schedule_cfg is not None, \
"Config 'engine.schedule' not found in your configuration file for pipeline parallel training"
schedule = PipelineSchedule(
amp_type=amp_type, amp_config=amp_cfg, **gpc.config.schedule.copy())
amp_type=amp_type, amp_config=amp_cfg, **schedule_cfg.copy())
else:
schedule = NoPipelineSchedule(amp_type=amp_type, amp_config=amp_cfg)
return model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler
engine = Engine(
model=model,
optimizer=optimizer,
criterion=criterion,
step_schedule=schedule,
**gpc.config.get('engine', dict())
)
return engine, train_dataloader, test_dataloader

View File

@ -7,6 +7,7 @@ from torch import Tensor
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device
from torch.cuda.amp import custom_bwd, custom_fwd
def matmul_2d(a,
@ -60,6 +61,7 @@ class Matmul_AB_2D(torch.autograd.Function):
"""Matrix multiplication for :math:`C = AB`
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
A: Tensor,
B: Tensor,
@ -120,32 +122,32 @@ class Matmul_AB_2D(torch.autograd.Function):
return out
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors
A_grad = Matmul_ABT_2D.forward(
None,
output_grad, B,
ctx.summa_dim, ctx.A_shape,
ctx.row_rank, ctx.col_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
B_grad = Matmul_ATB_2D.forward(
None,
A, output_grad,
ctx.summa_dim, ctx.B_shape,
ctx.row_rank, ctx.col_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
with torch.no_grad():
A_grad = Matmul_ABT_2D.apply(
output_grad, B,
ctx.summa_dim, ctx.A_shape,
ctx.row_rank, ctx.col_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
B_grad = Matmul_ATB_2D.apply(
A, output_grad,
ctx.summa_dim, ctx.B_shape,
ctx.row_rank, ctx.col_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None
@ -153,6 +155,7 @@ class Matmul_ABT_2D(torch.autograd.Function):
"""Matrix multiplication for :math:`C = AB^T`
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
A: Tensor,
B: Tensor,
@ -214,32 +217,33 @@ class Matmul_ABT_2D(torch.autograd.Function):
return out
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors
A_grad = Matmul_AB_2D.forward(
None,
output_grad, B,
ctx.summa_dim, ctx.A_shape,
ctx.row_rank, ctx.col_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
B_grad = Matmul_ATB_2D.forward(
None,
output_grad, A,
ctx.summa_dim, ctx.B_shape,
ctx.row_rank, ctx.col_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
with torch.no_grad():
A_grad = Matmul_AB_2D.apply(
output_grad, B,
ctx.summa_dim, ctx.A_shape,
ctx.row_rank, ctx.col_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
B_grad = Matmul_ATB_2D.apply(
output_grad, A,
ctx.summa_dim, ctx.B_shape,
ctx.row_rank, ctx.col_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None
@ -247,6 +251,7 @@ class Matmul_ATB_2D(torch.autograd.Function):
"""Matrix multiplication for :math:`C = A^TB`
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
A: Tensor,
B: Tensor,
@ -308,32 +313,33 @@ class Matmul_ATB_2D(torch.autograd.Function):
return out
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors
A_grad = Matmul_ABT_2D.forward(
None,
B, output_grad,
ctx.summa_dim, ctx.A_shape,
ctx.row_rank, ctx.col_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
B_grad = Matmul_AB_2D.forward(
None,
A, output_grad,
ctx.summa_dim, ctx.B_shape,
ctx.row_rank, ctx.col_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
with torch.no_grad():
A_grad = Matmul_ABT_2D.apply(
B, output_grad,
ctx.summa_dim, ctx.A_shape,
ctx.row_rank, ctx.col_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
B_grad = Matmul_AB_2D.apply(
A, output_grad,
ctx.summa_dim, ctx.B_shape,
ctx.row_rank, ctx.col_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None
@ -341,6 +347,7 @@ class Add_Bias_2D(torch.autograd.Function):
"""Matrix add bias: :math:`C = A + b`
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
input: Tensor,
bias: Tensor,
@ -384,6 +391,7 @@ class Add_Bias_2D(torch.autograd.Function):
return output
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
row_rank = ctx.row_rank
col_rank = ctx.col_rank
@ -423,6 +431,7 @@ class Add_Bias_2D(torch.autograd.Function):
class _LayerNorm_2D(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx: Any,
input: Tensor,
E_x: Tensor,
@ -440,6 +449,7 @@ class _LayerNorm_2D(torch.autograd.Function):
return output
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
row_parallel_mode = ctx.row_parallel_mode
col_parallel_mode = ctx.col_parallel_mode
@ -492,6 +502,7 @@ class _LayerNorm_2D(torch.autograd.Function):
class _ViT_Split_Input_2D(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
inputs: Tensor,
batch_size: int,
@ -509,6 +520,7 @@ class _ViT_Split_Input_2D(torch.autograd.Function):
return output
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
# output_grad: [b/q, s, h/q]
# grads: [b, s, h/q]

View File

@ -1,5 +1,5 @@
from .cosine import CosineAnnealingLR, CosineAnnealingWarmupLR, FlatAnnealingLR, FlatAnnealingWarmupLR
from .linear import LinearWarmupLR, LinearWarmupDecay
from .linear import LinearWarmupLR
from .multistep import MultiStepLR, MultiStepWarmupLR
from .onecycle import OneCycleLR
from .poly import PolynomialLR, PolynomialWarmupLR

View File

@ -66,11 +66,10 @@ class CosineAnnealingWarmupLR(WarmupScheduler):
:type last_epoch: int, optional
"""
def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, eta_min: int = 0, last_epoch: int = -1,
**kwargs):
def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, eta_min: int = 0, last_epoch: int = -1):
base_scheduler = _CosineAnnealingLR(
optimizer, total_steps - warmup_steps, eta_min=eta_min)
super().__init__(optimizer, warmup_steps, base_scheduler, last_epoch=last_epoch)
optimizer, total_steps - warmup_steps, eta_min=eta_min, last_epoch=last_epoch)
super().__init__(optimizer, warmup_steps, base_scheduler)
@LR_SCHEDULERS.register_module

View File

@ -55,7 +55,7 @@ class DelayerScheduler(_LRScheduler):
class WarmupScheduler(_LRScheduler):
""" Starts with a linear warmup lr schedule until it reaches N epochs the applies a scheduler
""" Starts with a linear warmup lr schedule until it reaches N epochs the applies a scheduler
:param optimizer: Wrapped optimizer.
:type optimizer: torch.optim.Optimizer
@ -66,11 +66,8 @@ class WarmupScheduler(_LRScheduler):
:param last_epoch: The index of last epoch, defaults to -1
:type last_epoch: int, optional
"""
def __init__(self, optimizer, warmup_epochs, after_scheduler, last_epoch=-1):
if warmup_epochs < 0:
raise ValueError(f'warmup_epochs must >= 0, got {warmup_epochs}')
self.warmup_epochs = warmup_epochs
self.warmup_epochs = int(warmup_epochs)
self.after_scheduler = after_scheduler
self.finished = False
super().__init__(optimizer, last_epoch)
@ -79,14 +76,10 @@ class WarmupScheduler(_LRScheduler):
if self.last_epoch >= self.warmup_epochs:
if not self.finished:
self.after_scheduler.base_lrs = self.base_lrs
# reset lr to base_lr
for group, base_lr in zip(self.optimizer.param_groups, self.base_lrs):
group['lr'] = base_lr
self.finished = True
with _enable_get_lr_call(self.after_scheduler):
return self.after_scheduler.get_lr()
return self.after_scheduler.get_lr()
return [(self.last_epoch + 1) / (self.warmup_epochs + 1) * lr for lr in self.base_lrs]
return [(self.last_epoch + 1) / self.warmup_epochs * lr for lr in self.base_lrs]
def step(self, epoch=None):
if self.finished:

View File

@ -28,18 +28,3 @@ class LinearWarmupLR(_LRScheduler):
else:
return [(self.total_steps - self.last_epoch) / (self.total_steps - self.warmup_steps) * lr for lr in
self.base_lrs]
@LR_SCHEDULERS.register_module
class LinearWarmupDecay(_LRScheduler):
def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, last_epoch: int = -1, **kwargs):
self.warmup_steps = int(warmup_steps)
self.total_steps = total_steps
super().__init__(optimizer, last_epoch=last_epoch)
def get_lr(self):
if self.last_epoch < self.warmup_steps:
return [(self.last_epoch + 1) / self.warmup_steps * lr for lr in self.base_lrs]
else:
return [(self.total_steps - self.last_epoch - 1) / (self.total_steps - self.warmup_steps) * lr for lr in
self.base_lrs]

View File

@ -27,12 +27,7 @@ class MultiStepLR(_MultiStepLR):
:type last_epoch: int, optional
"""
def __init__(self, optimizer, total_steps: int, milestones: List[int] = None, gamma: float = 0.1,
num_steps_per_epoch: int = -1, last_epoch: int = -1, **kwargs):
if num_steps_per_epoch <= 0:
raise ValueError(
f'num_steps_per_epoch must > 0, got {num_steps_per_epoch}')
milestones = [v * num_steps_per_epoch for v in milestones]
def __init__(self, optimizer, total_steps: int, milestones: List[int] = None, gamma: float = 0.1, last_epoch: int = -1, **kwargs):
super().__init__(optimizer, milestones, gamma=gamma, last_epoch=last_epoch)
@ -57,14 +52,11 @@ class MultiStepWarmupLR(WarmupScheduler):
"""
def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, milestones: List[int] = None,
gamma: float = 0.1, num_steps_per_epoch: int = -1, last_epoch: int = -1, **kwargs):
gamma: float = 0.1, last_epoch: int = -1, **kwargs):
if len(milestones) == 0:
raise ValueError('milestones cannot be empty')
if num_steps_per_epoch <= 0:
raise ValueError(
f'num_steps_per_epoch must > 0, got {num_steps_per_epoch}')
milestones = [v * num_steps_per_epoch - warmup_steps for v in milestones if v *
num_steps_per_epoch >= warmup_steps]
milestones = [
v - warmup_steps for v in milestones if v >= warmup_steps]
base_scheduler = _MultiStepLR(optimizer, milestones=milestones,
gamma=gamma)
super().__init__(optimizer, warmup_steps, base_scheduler, last_epoch=last_epoch)

View File

@ -1,7 +1,7 @@
from torch.optim.lr_scheduler import LambdaLR as _LambdaLR
from torch.optim.lr_scheduler import MultiplicativeLR as _MultiplicativeLR
from torch.optim.lr_scheduler import StepLR as _StepLR
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import ExponentialLR as _ExponentialLR
from colossalai.registry import LR_SCHEDULERS
@ -25,11 +25,8 @@ class LambdaLR(_LambdaLR):
:type last_epoch: int, optional
"""
def __init__(self, optimizer, total_steps, lr_lambda=None, num_steps_per_epoch: int = -1,
last_epoch: int = -1) -> None:
def func(step): return lr_lambda(step // num_steps_per_epoch)
super().__init__(optimizer, func, last_epoch=last_epoch)
def __init__(self, optimizer, total_steps, lr_lambda=None, last_epoch: int = -1) -> None:
super().__init__(optimizer, lr_lambda, last_epoch=last_epoch)
@LR_SCHEDULERS.register_module
@ -51,11 +48,8 @@ class MultiplicativeLR(_MultiplicativeLR):
:type last_epoch: int, optional
"""
def __init__(self, optimizer, total_steps, lr_lambda=None, num_steps_per_epoch: int = -1,
last_epoch: int = -1) -> None:
def func(step): return lr_lambda(step // num_steps_per_epoch)
super().__init__(optimizer, func, last_epoch=last_epoch)
def __init__(self, optimizer, total_steps, lr_lambda=None, last_epoch: int = -1) -> None:
super().__init__(optimizer, lr_lambda, last_epoch=last_epoch)
@LR_SCHEDULERS.register_module
@ -79,14 +73,13 @@ class StepLR(_StepLR):
:type last_epoch: int, optional
"""
def __init__(self, optimizer, total_steps, step_size: int = 1, gamma: float = 0.1, num_steps_per_epoch: int = -1,
last_epoch: int = -1) -> None:
super().__init__(optimizer, step_size * num_steps_per_epoch,
def __init__(self, optimizer, total_steps, step_size: int = 1, gamma: float = 0.1, last_epoch: int = -1) -> None:
super().__init__(optimizer, step_size,
gamma=gamma, last_epoch=last_epoch)
@LR_SCHEDULERS.register_module
class ExponentialLR(_LRScheduler):
class ExponentialLR(_ExponentialLR):
"""Decays the learning rate of each parameter group by gamma every epoch.
When last_epoch=-1, sets initial lr as lr
@ -102,21 +95,6 @@ class ExponentialLR(_LRScheduler):
:type last_epoch: int, optional
"""
def __init__(self, optimizer, total_steps, gamma: float = 1.0, num_steps_per_epoch: int = -1,
def __init__(self, optimizer, total_steps, gamma: float = 1.0,
last_epoch: int = -1) -> None:
self.gamma = gamma
self.num_steps_per_epoch = num_steps_per_epoch
super().__init__(optimizer, last_epoch=last_epoch)
def get_lr(self):
if self.last_epoch == 0:
return self.base_lrs
elif (self.last_epoch + 1) % self.num_steps_per_epoch == 0:
return [group['lr'] * self.gamma
for group in self.optimizer.param_groups]
return [group['lr']
for group in self.optimizer.param_groups]
def _get_closed_form_lr(self):
return [base_lr * self.gamma ** (self.last_epoch // self.num_steps_per_epoch)
for base_lr in self.base_lrs]
super().__init__(optimizer, gamma, last_epoch=last_epoch)

View File

@ -106,7 +106,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
tensor_parallel_norm = _calc_lp(tensor_parallel_grads, norm_type)
no_tensor_parallel_grads = _calc_lp(
no_tensor_parallel_grads, norm_type)
if gpc.is_initialized(ParallelMode.TENSOR):
if gpc.is_initialized(ParallelMode.TENSOR) and len(tensor_parallel_grads) > 0:
# Sum across all model-parallel GPUs.
torch.distributed.all_reduce(tensor_parallel_norm,
op=torch.distributed.ReduceOp.SUM,

View File

@ -6,6 +6,7 @@ import math
import torch
import torch.distributed as dist
try:
from deepspeed.git_version_info import version
from deepspeed.moe.utils import is_moe_param
@ -13,7 +14,7 @@ try:
from deepspeed.ops.op_builder import UtilsBuilder
from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION_GRADIENTS
except ImportError:
print('DeepSpeed is required if you want to use ZeRO.')
pass
from packaging import version as pkg_version
from torch._six import inf
from torch.distributed.distributed_c10d import _get_global_rank
@ -251,7 +252,7 @@ class ZeroRedundancyOptimizer_Level_2(Optimizer):
self.nccl_start_alignment_factor = 2
assert (
allgather_bucket_size % self.nccl_start_alignment_factor == 0), f"allgather_bucket_size must be a multiple of nccl_start_alignment_factor, {self.nccl_start_alignment_factor} "
allgather_bucket_size % self.nccl_start_alignment_factor == 0), f"allgather_bucket_size must be a multiple of nccl_start_alignment_factor, {self.nccl_start_alignment_factor} "
self.all_reduce_print = False
self.dtype = self.optimizer.param_groups[0]['params'][0].dtype
@ -759,7 +760,7 @@ class ZeroRedundancyOptimizer_Level_2(Optimizer):
elif start_index > current_index and start_index < (current_index +
param_size):
assert (
first_offset == 0), "This can happen either zero or only once as this must be the first tensor in the partition"
first_offset == 0), "This can happen either zero or only once as this must be the first tensor in the partition"
first_offset = start_index - current_index
set_key_value_list(self.param_to_partition_ids[i],
@ -803,7 +804,7 @@ class ZeroRedundancyOptimizer_Level_2(Optimizer):
def report_ipg_memory_usage(self, tag, param_elems):
elem_count = self.elements_in_ipg_bucket + param_elems
percent_of_bucket_size = (
100.0 * elem_count) // self.reduce_bucket_size
100.0 * elem_count) // self.reduce_bucket_size
if self.verbose:
report_memory_usage(
f"{tag}: elems in_bucket {self.elements_in_ipg_bucket} param {param_elems} max_percent {percent_of_bucket_size}"
@ -1491,7 +1492,7 @@ class ZeroRedundancyOptimizer_Level_2(Optimizer):
params_in_partition.append(tensor)
assert (
first_offset == 0), "This can happen either zero or only once as this must be the first tensor in the partition"
first_offset == 0), "This can happen either zero or only once as this must be the first tensor in the partition"
first_offset = start_index - current_index
else:
@ -1799,7 +1800,7 @@ class ZeroRedundancyOptimizer_Level_2(Optimizer):
num_elements = shard_size
assert shard_size * \
num_shards <= partitioned_params[partition_id].numel()
num_shards <= partitioned_params[partition_id].numel()
for shard_id in range(num_shards):
@ -2248,7 +2249,7 @@ def estimate_zero2_model_states_mem_needs(total_params,
if cpu_offload:
gpu_mem = 2 * total_params
cpu_mem = total_params * \
max(4 * total_gpus, 16) * additional_buffer_factor
max(4 * total_gpus, 16) * additional_buffer_factor
else:
gpu_mem = 4 * total_params + int(16 * total_params / total_gpus)
cpu_mem = total_params * 4 * num_gpus_per_node * additional_buffer_factor

View File

@ -21,7 +21,7 @@ try:
from deepspeed.runtime.zero.partition_parameters import *
from deepspeed.runtime.zero.partition_parameters import _init_external_params
except ImportError:
print('DeepSpeed is required if you want to use ZeRO.')
pass
from torch._six import inf
from torch.distributed.distributed_c10d import _get_global_rank

View File

@ -20,3 +20,4 @@ TRANSFORMS = Registry('transforms', third_party_library=[transforms])
PIPE_ALLOC_POLICY = Registry('pipeline_allocation_policy')
SAMPLERS = Registry('samplers')
LR_SCHEDULERS = Registry('lr_schedulers')
SCHEDULE = Registry('schedules')

View File

@ -1,5 +1,5 @@
from ._trainer import Trainer
from .hooks import *
from .metric import Loss, Accuracy2D, Accuracy3D, Accuracy2p5D
from .metric import Loss, Accuracy2D, Accuracy3D, Accuracy2p5D, LearningRate
__all__ = ['Trainer', 'Loss', 'Accuracy3D', 'Accuracy2D', 'Accuracy2p5D']
__all__ = ['Trainer', 'Loss', 'Accuracy3D', 'Accuracy2D', 'Accuracy2p5D', 'LearningRate']

View File

@ -1,7 +1,6 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Optional
from typing import Union, List
import torch
@ -10,12 +9,11 @@ from torch.utils.data import DataLoader
from tqdm import tqdm
from colossalai.builder import build_hooks
from colossalai.checkpointing import save_checkpoint, load_checkpoint, get_checkpoint_path
from colossalai.context import Config
from colossalai.engine import Engine
from colossalai.logging import get_global_dist_logger
from colossalai.utils import get_global_multitimer, is_dp_rank_0, is_tp_rank_0, is_no_pp_or_last_stage
from colossalai.nn.data import DataParallelSampler
from colossalai.utils import MultiTimer
from colossalai.utils import is_dp_rank_0, is_tp_rank_0, is_no_pp_or_last_stage
class Trainer:
@ -30,43 +28,31 @@ class Trainer:
:type hoooks_cfg: Config, optional
:type verbose: bool, optional
"""
def __init__(self,
engine: Engine,
hooks_cfg: Optional[Config] = None,
verbose: bool = False):
verbose: bool = False,
timer: MultiTimer = None):
# training-ralated params
self._engine = engine
self._max_epochs = float('inf')
self._max_steps = float('inf')
self._max_epochs = 0
self._cur_epoch = 0
self._max_steps = 0
self._cur_step = 0
# data-related params
self._train_dataloader = None
self._test_dataloader = None
self._steps_per_epoch = 0
# misc params
self._display_progress = False
self._logger = get_global_dist_logger()
self._verbose = verbose
# hooks can store states in this dict, and could be consumed by other hooks
self.states = {}
self.states = dict()
# build hooks
self.hooks = list()
if hooks_cfg is not None:
for cfg in hooks_cfg:
hook = build_hooks(cfg, self)
self.hooks.append(hook)
self.hooks.sort(key=lambda hook: hook.priority)
if self._verbose:
for hook in self.hooks:
self._logger.info(
f'build {hook.__class__.__name__} for train, priority = {hook.priority}', ranks=[0])
# timer
self._timer = get_global_multitimer()
# multi-timer for time benchmarking
self._timer = timer
@property
def cur_epoch(self):
@ -74,13 +60,65 @@ class Trainer:
"""
return self._cur_epoch
@cur_epoch.setter
def cur_epoch(self, epoch: int):
"""Set how many epochs have been processed.
"""
# allow setter for training resumption
self._cur_epoch = epoch
@property
def cur_step(self):
"""Returns how many iteration steps have been processed.
"""
return self._cur_step
def call_hooks(self, func, output=None):
@property
def max_epochs(self):
return self._max_epochs
@property
def max_steps(self):
return self._max_steps
@property
def steps_per_epoch(self):
return self._steps_per_epoch
@property
def engine(self):
return self._engine
@engine.setter
def engine(self, engine_: Engine):
self._engine = engine_
def _set_current_step(self, epoch: int):
"""Sets current step number.
:param epoch: Step number to be set
:type epoch: int
"""
self._cur_step = epoch * self._steps_per_epoch
def _call_timer(self, action: str, item: str, *args, **kwargs) -> None:
"""Call timer funciton with a given timer name.
:param action: Function to be called on timer
:type action: str
:param item: Name of the timer
:type item: str
"""
if self._timer is not None:
getattr(self._timer, action)(item, *args, **kwargs)
def _reset_states(self) -> None:
"""Clear trainer states
"""
self.states = dict()
def _call_hooks(self, func, output=None):
"""Calls specific hooks in the current time point.
:param func: A string represents the time point
@ -95,161 +133,186 @@ class Trainer:
else:
getattr(hook, func)(*output)
def exceed_max_step(self):
"""Checks whether the trainer exceeds the maximum number of runnning iterations.
@staticmethod
def _should_display_progress(display_progress: bool):
""" Only display progress on DP rank 0, TP rank 0 and PP last rank
"""
return self._cur_step >= self._max_steps
return display_progress and is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage()
def set_epoch(self, epoch):
"""Sets current epoch number.
:param epoch: Epoch number to be set
:type epoch: int
"""
self._cur_epoch = epoch
def _recover_steps(self):
step = self.cur_step * self._engine.schedule.num_steps
self._cur_step = step
def _set_display_progress(self, display_progress: bool):
self._display_progress = display_progress and is_dp_rank_0(
) and is_tp_rank_0() and is_no_pp_or_last_stage()
def _train_epoch(self, epoch: int = None):
def _train_epoch(self,
train_dataloader: DataLoader,
epoch: int = None,
display_progress: bool = False):
# set sampler epoch
if epoch is not None and \
hasattr(self._engine.train_dataloader, 'sampler') and \
isinstance(self._engine.train_dataloader.sampler, DataParallelSampler):
self._engine.train_dataloader.sampler.set_epoch(epoch)
hasattr(train_dataloader, 'sampler') and \
isinstance(train_dataloader.sampler, DataParallelSampler):
train_dataloader.sampler.set_epoch(epoch)
# set training state
self._engine.train()
progress = range(self._engine.schedule.num_steps)
if self._display_progress:
data_iter = iter(train_dataloader)
progress = range(self._steps_per_epoch)
if display_progress:
if epoch is None:
progress = tqdm(progress, desc='[Train]')
else:
progress = tqdm(progress, desc=f'[Epoch {epoch} train]')
# train 1 epoch
self.call_hooks('before_train_epoch')
self._timer.start('train-epoch')
for _ in progress:
self._call_hooks('before_train_epoch')
self._call_timer(action='start', item='train-epoch')
for i in progress:
self._call_hooks('before_train_iter')
self._call_timer(action='start', item='train-step')
if i == self._steps_per_epoch - 1:
is_last_iteration = True
else:
is_last_iteration = False
# run 1 training step
logits, label, loss = self._engine.step(data_iter, is_last_iteration)
self._call_timer(action='stop', item='train-step', keep_in_history=True)
self._call_hooks('after_train_iter', output=(logits, label, loss))
self._cur_step += 1
self.call_hooks('before_train_iter')
self._timer.start('train-step')
logits, label, loss = self._engine.step()
self._timer.stop('train-step', keep_in_history=True)
self.call_hooks('after_train_iter', output=(logits, label, loss))
if self.exceed_max_step():
# stop when max iter is reached
# stop when max iter is reached
if self._exceed_max_step():
break
self._timer.stop('train-epoch', keep_in_history=True)
self.call_hooks('after_train_epoch')
self._timer.reset('train-step')
self._call_timer(action='stop', item='train-epoch', keep_in_history=True)
self._call_hooks('after_train_epoch')
self._call_timer(action='reset', item='train-step')
def _eval(self,
test_dataloader: DataLoader,
epoch: int = None,
return_loss: bool = True):
display_progress: bool = False):
# switch engine status
self._engine.eval()
self.call_hooks('before_test')
data_iter = iter(test_dataloader)
num_steps = len(test_dataloader)
self._call_hooks('before_test')
with torch.no_grad():
# prepare progress bar
progress = range(self._engine.schedule.num_steps)
if self._display_progress:
progress = range(num_steps)
if display_progress:
desc = 'Evaluation'
if epoch is not None:
desc = '[Epoch %d val]' % epoch
progress = tqdm(progress, desc=desc)
self.call_hooks('before_test_epoch')
self._timer.start('test-epoch')
self._call_hooks('before_test_epoch')
self._call_timer(action='start', item='test-epoch')
for _ in progress:
self.call_hooks('before_test_iter')
self._timer.start('test-step')
logits, label, loss = self._engine.step(
return_loss=return_loss)
self._timer.stop('test-step', keep_in_history=True)
self.call_hooks('after_test_iter',
output=(logits, label, loss))
self._timer.stop('test-epoch', keep_in_history=True)
self.call_hooks('after_test_epoch')
self.call_hooks('after_test')
self._timer.reset('test-step')
self._timer.reset('test-epoch')
self._call_hooks('before_test_iter')
self._call_timer(action='start', item='test-step')
logits, label, loss = self._engine.step(data_iter, return_loss=True)
self._call_timer(action='stop', item='test-step', keep_in_history=True)
self._call_hooks('after_test_iter',
output=(logits, label, loss))
self._call_timer(action='stop', item='test-epoch', keep_in_history=True)
self._call_hooks('after_test_epoch')
self._call_hooks('after_test')
self._call_timer(action='reset', item='test-step')
self._call_timer(action='reset', item='test-epoch')
def _exceed_max_step(self):
return self._max_steps is not None and self._cur_step > self._max_steps
def fit(self,
train_dataloader: DataLoader,
test_dataloader: DataLoader = None,
max_epochs: int = None,
epochs: int,
max_steps: int = None,
test_dataloader: DataLoader = None,
test_interval: int = 1,
display_progress: bool = False):
hooks_cfg: dict = None,
display_progress: bool = False,
):
"""Trains the model to fit training data.
:param train_dataloader: DataLoader in training
:param test_dataloader: DataLoader in testing
:param max_epochs: Maximum number of epoches
:param epochs: Maximum number of epoches
:param max_steps: Maximum number of running iterations
:param test_dataloader: DataLoader in testing
:param test_interval: Interval of testing
:param hooks_cfg: A list of hook configuration
:param display_progress: If True, the training progress will be printed
:type train_dataloader: DataLoader
:type test_dataloader: DataLoader
:type max_epochs: int
:type epochs: int
:type max_steps: int
:type test_dataloader: DataLoader
:type test_interval: int
:type hooks_cfg: dict
:type display_progress: bool
:type gradient_accumulation: int
"""
# prepare dataloaders
self._train_dataloader = train_dataloader
self._engine.set_dataloader(self._train_dataloader, train=True)
self._engine.train()
# set epochs and steps, consider gradient accumulation
self._steps_per_epoch = len(train_dataloader) // self._engine.gradient_accumulation
self._max_steps = max_steps
self._max_epochs = epochs
# check if testing is required
should_test = False
if test_dataloader is not None:
self._test_dataloader = test_dataloader
self._engine.set_dataloader(self._test_dataloader, train=False)
should_test = True
# decide the
if max_epochs is not None:
self._max_epochs = max_epochs
if max_steps is not None:
self._max_steps = max_steps
self._set_display_progress(display_progress)
display_progress = self._should_display_progress(display_progress)
# reset hooks
self._reset_states()
self.hooks = list()
# build hooks
if hooks_cfg is not None:
for cfg in hooks_cfg:
hook = build_hooks(cfg, self)
self.hooks.append(hook)
self.hooks.sort(key=lambda hook: hook.priority)
if self._verbose:
for hook in self.hooks:
self._logger.info(
f'build {hook.__class__.__name__} for training, priority = {hook.priority}', ranks=[0])
self._logger.info("Lower value means higher priority for calling hook function")
# start train
self.call_hooks('before_train')
self._engine.train()
self._call_hooks('before_train')
# recover step value if resuming training
if self.cur_epoch != 0:
self._recover_steps()
last_epoch = self._cur_epoch
if self.cur_epoch != 0:
self._set_current_step(last_epoch)
for epoch in range(last_epoch, self._max_epochs):
self._cur_epoch += 1
for epoch in range(last_epoch, epochs):
# train for one epoch
self._train_epoch(epoch)
self._train_epoch(
train_dataloader=train_dataloader,
epoch=epoch,
display_progress=display_progress
)
# start eval
if should_test and epoch % test_interval == 0:
self._eval(epoch, return_loss=True)
self._eval(test_dataloader=test_dataloader,
display_progress=display_progress,
epoch=epoch,
)
self._cur_epoch += 1
# check for termination
if self.exceed_max_step():
if self._exceed_max_step():
self._logger.info(
f"Max number of steps {self._max_steps} has been reached, training is stopped automatically")
f"Max number of steps {max_steps} has been reached, training is stopped automatically")
break
self.call_hooks('after_train')
self._timer.reset('train-epoch')
self._call_hooks('after_train')
self._call_timer('reset', 'train-epoch')
def evaluate(self,
test_dataloader: DataLoader,
@ -261,15 +324,13 @@ class Trainer:
:type test_dataloader: DataLoader
:type display_progress: bool, optional
"""
# set dataloader
self._test_dataloader = test_dataloader
self._engine.set_dataloader(self._test_dataloader, train=True)
# set
self._set_display_progress(display_progress)
# set display
display_progress = self._should_display_progress(display_progress)
# eval
self._eval(return_loss=True)
self._eval(test_dataloader=test_dataloader,
display_progress=display_progress,
)
def predict(self, data: Union[Tensor, List[Tensor]]):
"""Uses trained model to make a prediction for a tensor or a tensor list.
@ -289,45 +350,6 @@ class Trainer:
# prepare a list of (data, label) to make it iterable
# for compatibility with schedule
simple_dataloader = [(data, None)]
self._engine.set_dataloader(simple_dataloader)
output, _, _ = self._engine.step(return_loss=False)
data_iter = iter(simple_dataloader)
output, _, _ = self._engine.step(data_iter, return_loss=False)
return output
def save(self, path: str, suffix: str = ''):
"""Saves the model to a file.
:param path: Relative path of the file
:param suffix: Suffix of the file
:type path: str
:type suffix: str, optional
"""
save_path = get_checkpoint_path(path,
self._cur_epoch,
suffix=suffix)
save_checkpoint(save_path, self._cur_epoch, self._engine.get_model(),
self._engine.get_optimizer(),
self._engine.get_lr_scheduler())
def load(self,
path: str,
finetune: bool = False,
strict: bool = False):
"""Loads parameters to the model from a file.
:param path: Relative path of the file
:param finetune: Whether allows to load a part of the model
:param strict: Whether loads a model that has the same shape of parameters
:type path: str
:type finetune: bool, optional
:type strict: bool, optional
"""
last_epoch, _ = load_checkpoint(path,
self._engine.get_model(),
self._engine.get_optimizer(),
self._engine.get_lr_scheduler(),
finetune=finetune,
strict=strict)
if finetune:
self.set_epoch(0)
else:
self.set_epoch(last_epoch)

View File

@ -2,10 +2,12 @@ from ._base_hook import BaseHook
from ._checkpoint_hook import SaveCheckpointHook, LoadCheckpointHook
from ._metric_hook import LossHook, Accuracy2DHook, AccuracyHook, MetricHook
from ._log_hook import LogMetricByEpochHook, TensorboardHook, LogTimingByEpochHook, LogMemoryByEpochHook
from ._lr_scheduler_hook import LRSchedulerHook
__all__ = [
'BaseHook', 'MetricHook',
'LoadCheckpointHook', 'SaveCheckpointHook',
'LossHook', 'AccuracyHook', 'Accuracy2DHook',
'LogMetricByEpochHook', 'TensorboardHook', 'LogTimingByEpochHook', 'LogMemoryByEpochHook',
'LRSchedulerHook'
]

View File

@ -3,13 +3,13 @@
import os.path as osp
import torch.distributed as dist
from colossalai.checkpointing import get_latest_checkpoint_path, get_checkpoint_path
from colossalai.registry import HOOKS
from colossalai.trainer.hooks import BaseHook
from colossalai.trainer import Trainer
from colossalai.trainer.hooks import BaseHook
from colossalai.utils import is_dp_rank_0
from colossalai.utils.checkpointing import get_latest_checkpoint_path, get_checkpoint_path
from colossalai.utils.checkpointing import save_checkpoint, load_checkpoint
from ._lr_scheduler_hook import LRSchedulerHook
@HOOKS.register_module
@ -33,7 +33,7 @@ class SaveCheckpointHook(BaseHook):
interval: int = 1,
checkpoint_dir: str = None,
suffix: str = '',
priority: int = 0):
priority: int = 10):
super().__init__(trainer=trainer, priority=priority)
assert isinstance(trainer, Trainer), \
f'SaveCheckpointHook expects a Trainer, got {type(trainer)}'
@ -41,6 +41,16 @@ class SaveCheckpointHook(BaseHook):
self.checkpoint_dir = checkpoint_dir
self.suffix = suffix
# get lr scheduler from the LRSchedulerHook before train
self._lr_scheduler = None
def before_train(self):
# check if lr scheduler is present in LRSchedulerHook
for hook in self.trainer.hooks:
if isinstance(hook, LRSchedulerHook):
self._lr_scheduler = hook.lr_scheduler
break
def after_train_epoch(self):
"""Saves the model after a training epoch.
"""
@ -48,14 +58,18 @@ class SaveCheckpointHook(BaseHook):
if self.trainer.cur_epoch % self.interval == 0:
# only gpus with data parallel rank equals to 0 write to the disk
if is_dp_rank_0():
self.trainer.save(path=self.checkpoint_dir, suffix=self.suffix)
save_path = get_checkpoint_path(self.checkpoint_dir,
self.trainer.cur_epoch,
suffix=self.suffix)
save_checkpoint(save_path,
self.trainer.cur_epoch,
self.trainer.engine.model,
self.trainer.engine.optimizer,
self._lr_scheduler)
self.logger.info(
f'checkpoint for epoch {self.trainer.cur_epoch} is saved to {self.checkpoint_dir}')
# wait until everyone is done
if dist.is_initialized():
dist.barrier()
@HOOKS.register_module
class LoadCheckpointHook(BaseHook):
@ -81,30 +95,46 @@ class LoadCheckpointHook(BaseHook):
epoch: int = -1,
finetune: bool = False,
strict: bool = False,
priority: int = 10) -> None:
suffix: str = '',
priority: int = 0) -> None:
super().__init__(trainer=trainer, priority=priority)
assert isinstance(trainer, Trainer), \
f'LoadLatestCheckpointHook excepts a Trainer, got {type(trainer)}'
self.epoch = epoch
self.checkpoint_dir = checkpoint_dir
self.finetune = finetune
self.suffix = suffix
self.strict = strict
super().__init__(trainer=trainer, priority=priority)
def before_train(self):
"""Loads parameters to the model before training.
"""
# check if lr scheduler is present in LRSchedulerHook
lr_scheduler = None
for hook in self.trainer.hooks:
if isinstance(hook, LRSchedulerHook):
lr_scheduler = hook.lr_scheduler
break
# use latest checkpoint if epoch = -1
if self.epoch == -1:
path = get_latest_checkpoint_path(self.checkpoint_dir)
path = get_latest_checkpoint_path(self.checkpoint_dir, suffix=self.suffix)
else:
path = get_checkpoint_path(self.checkpoint_dir, epoch=self.epoch)
path = get_checkpoint_path(self.checkpoint_dir, epoch=self.epoch, suffix=self.suffix)
if osp.exists(path):
self.trainer.load(
path, finetune=self.finetune, strict=self.strict)
last_epoch, _ = load_checkpoint(path,
self.trainer.engine.model,
self.trainer.engine.optimizer,
lr_scheduler,
finetune=self.finetune,
strict=self.strict)
if self.finetune:
self.trainer.cur_epoch = 0
else:
self.trainer.cur_epoch = last_epoch
self.logger.info(
f'loaded checkpoint from {path}')
else:
raise FileNotFoundError(f'checkpoint is not found at {path}')
# Some utilities want to load a checkpoint without distributed being initialized
if dist.is_initialized():
dist.barrier()

View File

@ -5,7 +5,7 @@ import os
import os.path as osp
import torch
from tensorboardX import SummaryWriter
from torch.utils.tensorboard import SummaryWriter
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
@ -13,7 +13,7 @@ from colossalai.registry import HOOKS
from colossalai.trainer._trainer import Trainer
from colossalai.utils import get_global_multitimer, set_global_multitimer_status, report_memory_usage, is_dp_rank_0, \
is_tp_rank_0, is_no_pp_or_last_stage
from ._metric_hook import MetricHook
from ._base_hook import BaseHook
def _format_number(val):
@ -24,7 +24,7 @@ def _format_number(val):
return val
class EpochIntervalHook(MetricHook):
class EpochIntervalHook(BaseHook):
def __init__(self, trainer: Trainer, interval: int = 1, priority: int = 1):
super().__init__(trainer, priority)
self._interval = interval
@ -45,7 +45,7 @@ class LogMetricByEpochHook(EpochIntervalHook):
:type priority: int, optional
"""
def __init__(self, trainer: Trainer, interval: int = 1, priority: int = 1) -> None:
def __init__(self, trainer: Trainer, interval: int = 1, priority: int = 10) -> None:
super().__init__(trainer=trainer, interval=interval, priority=priority)
self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage()
@ -74,7 +74,7 @@ class LogMetricByEpochHook(EpochIntervalHook):
@HOOKS.register_module
class TensorboardHook(MetricHook):
class TensorboardHook(BaseHook):
"""Specialized Hook to record the metric to Tensorboard.
:param trainer: Trainer attached with current hook
@ -85,59 +85,71 @@ class TensorboardHook(MetricHook):
:type priority: int, optional
"""
def __init__(self, trainer: Trainer, log_dir: str, priority: int = 1) -> None:
def __init__(self,
trainer: Trainer,
log_dir: str,
dp_rank_0_only: bool = True,
tp_rank_0_only: bool = True,
priority: int = 10,
) -> None:
super().__init__(trainer=trainer, priority=priority)
self._is_rank_to_log = is_no_pp_or_last_stage()
if self._is_rank_to_log:
# create log dir
if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0:
os.makedirs(log_dir, exist_ok=True)
# determine the ranks to generate tensorboard logs
self._is_valid_rank_to_log = is_no_pp_or_last_stage()
if dp_rank_0_only:
self._is_valid_rank_to_log = self._is_valid_rank_to_log and is_dp_rank_0()
if tp_rank_0_only:
self._is_valid_rank_to_log = self._is_valid_rank_to_log and is_tp_rank_0()
if self._is_valid_rank_to_log:
# create workspace on only one rank
if gpc.is_initialized(ParallelMode.GLOBAL):
rank = gpc.get_global_rank()
else:
rank = 0
log_dir = osp.join(log_dir, f'rank_{rank}')
# create workspace
if not osp.exists(log_dir):
os.makedirs(log_dir)
log_dir = osp.join(log_dir, f'rank_{rank}')
os.makedirs(log_dir, exist_ok=True)
self.writer = SummaryWriter(
log_dir=log_dir, filename_suffix=f'_rank_{rank}')
def after_train_iter(self, *args):
for metric_name, metric_calculator in self.trainer.states['metrics']['train'].items():
def _log_by_iter(self, mode: str):
for metric_name, metric_calculator in self.trainer.states['metrics'][mode].items():
if metric_calculator.epoch_only:
continue
val = metric_calculator.get_last_step_value()
if self._is_rank_to_log:
self.writer.add_scalar(
f'{metric_name}/train', val, self.trainer.cur_step)
def after_test_iter(self, *args):
for metric_name, metric_calculator in self.trainer.states['metrics']['test'].items():
if metric_calculator.epoch_only:
continue
val = metric_calculator.get_last_step_value()
if self._is_rank_to_log:
self.writer.add_scalar(f'{metric_name}/test', val,
if self._is_valid_rank_to_log:
self.writer.add_scalar(f'{metric_name}/{mode}', val,
self.trainer.cur_step)
def after_test_epoch(self):
for metric_name, metric_calculator in self.trainer.states['metrics']['test'].items():
def _log_by_epoch(self, mode: str):
for metric_name, metric_calculator in self.trainer.states['metrics'][mode].items():
if metric_calculator.epoch_only:
val = metric_calculator.get_accumulated_value()
if self._is_rank_to_log:
self.writer.add_scalar(f'{metric_name}/test', val,
if self._is_valid_rank_to_log:
self.writer.add_scalar(f'{metric_name}/{mode}', val,
self.trainer.cur_step)
def after_test_iter(self, *args):
self._log_by_iter(mode='test')
def after_test_epoch(self):
self._log_by_epoch(mode='test')
def after_train_iter(self, *args):
self._log_by_iter(mode='train')
def after_train_epoch(self):
for metric_name, metric_calculator in self.trainer.states['metrics']['train'].items():
if metric_calculator.epoch_only:
val = metric_calculator.get_accumulated_value()
if self._is_rank_to_log:
self.writer.add_scalar(f'{metric_name}/train', val,
self.trainer.cur_step)
self._log_by_epoch(mode='train')
@HOOKS.register_module
@ -157,7 +169,7 @@ class LogTimingByEpochHook(EpochIntervalHook):
def __init__(self,
trainer: Trainer,
interval: int = 1,
priority: int = 1,
priority: int = 10,
log_eval: bool = True
) -> None:
super().__init__(trainer=trainer, interval=interval, priority=priority)
@ -217,7 +229,7 @@ class LogMemoryByEpochHook(EpochIntervalHook):
def __init__(self,
trainer: Trainer,
interval: int = 1,
priority: int = 1,
priority: int = 10,
log_eval: bool = True
) -> None:
super().__init__(trainer=trainer, interval=interval, priority=priority)

View File

@ -0,0 +1,58 @@
from torch import Tensor
from colossalai.builder import build_lr_scheduler
from colossalai.registry import HOOKS
from ._metric_hook import MetricHook
from .._trainer import Trainer
from ..metric import LearningRate
@HOOKS.register_module
class LRSchedulerHook(MetricHook):
"""Build LR scheduler
:param trainer: Trainer attached with current hook
:type trainer: Trainer
:param lr_scheduler_cfg: The config of LR scheduler
:type lr_scheduler_cfg: dict
:param by_epoch: If `True`, the LR will be scheduled every epoch. Else, the LR will be scheduled every batch. Defaults to `True`.
:type by_epoch: bool
:param priority: Priority in the printing, hooks with small priority will be printed in front
:type priority: int, optional
"""
def __init__(self,
trainer: Trainer,
lr_scheduler_cfg: dict,
by_epoch: bool = True,
store_lr_in_state: bool = True,
priority: int = 1,
):
super().__init__(trainer=trainer, priority=priority)
self.by_epoch = by_epoch
if by_epoch:
total_steps = trainer.max_epochs
else:
total_steps = trainer.max_epochs * trainer.steps_per_epoch
if trainer.max_steps is not None:
total_steps = min(total_steps, trainer.max_steps)
lr_scheduler_cfg['total_steps'] = total_steps
self.lr_scheduler = build_lr_scheduler(
lr_scheduler_cfg, trainer.engine.optimizer)
if store_lr_in_state:
self.trainer.states['metrics']['train']['lr'] = LearningRate(epoch_only=by_epoch,
initial_lr=self.lr_scheduler.get_lr()[0])
def after_train_epoch(self):
if self.by_epoch:
self.lr_scheduler.step()
self.trainer.states['metrics']['train']['lr'].update(self.lr_scheduler.get_lr()[0])
def after_train_iter(self, output: Tensor, label: Tensor, loss: Tensor):
if not self.by_epoch:
self.lr_scheduler.step()
self.trainer.states['metrics']['train']['lr'].update(self.lr_scheduler.get_lr()[0])

View File

@ -21,9 +21,12 @@ class MetricHook(BaseHook):
:type priority: int
"""
def __init__(self, trainer: Trainer, priority: int):
def __init__(self,
trainer: Trainer,
priority: int,
):
super().__init__(trainer, priority)
self._is_stage_to_log = is_no_pp_or_last_stage()
self._is_stage_to_compute = is_no_pp_or_last_stage()
self._check_metric_states_initialization()
def _check_metric_states_initialization(self):
@ -41,33 +44,34 @@ class LossHook(MetricHook):
:type priority: int, optional
"""
def __init__(self, trainer: Trainer, priority: int = 10):
def __init__(self, trainer: Trainer, priority: int = 0):
super().__init__(trainer, priority)
if self._is_stage_to_log:
self.metric = Loss(epoch_only=False)
if self._is_stage_to_compute:
self.train_loss = Loss(epoch_only=False)
self.test_loss = Loss(epoch_only=True)
# register the metric calculator
self.trainer.states['metrics']['train'][
self.metric.__class__.__name__] = self.metric
self.train_loss.__class__.__name__] = self.train_loss
self.trainer.states['metrics']['test'][
self.metric.__class__.__name__] = self.metric
self.test_loss.__class__.__name__] = self.test_loss
def before_train_epoch(self):
if self._is_stage_to_log:
self.metric.reset()
if self._is_stage_to_compute:
self.train_loss.reset()
def after_train_iter(self, logits, label, loss):
if self._is_stage_to_log:
self.metric.update(loss)
if self._is_stage_to_compute:
self.train_loss.update(loss)
def before_test_epoch(self):
if self._is_stage_to_log:
self.metric.reset()
if self._is_stage_to_compute:
self.test_loss.reset()
def after_test_iter(self, logits, label, loss):
if self._is_stage_to_log:
self.metric.update(loss)
if self._is_stage_to_compute:
self.test_loss.update(loss)
@HOOKS.register_module
@ -81,10 +85,10 @@ class Accuracy2DHook(MetricHook):
:type priority: int, optional
"""
def __init__(self, trainer: Trainer, priority: int = 10):
def __init__(self, trainer: Trainer, priority: int = 0):
super().__init__(trainer, priority)
if self._is_stage_to_log:
if self._is_stage_to_compute:
self.metric = Accuracy2D(epoch_only=True)
# register the metric
@ -92,20 +96,20 @@ class Accuracy2DHook(MetricHook):
self.metric.__class__.__name__] = self.metric
def before_test(self):
if self._is_stage_to_log:
if self._is_stage_to_compute:
self.metric.reset()
def after_test_iter(self, logits, label, *args):
if self._is_stage_to_log:
if self._is_stage_to_compute:
self.metric.update(logits, label)
@HOOKS.register_module
class Accuracy2p5DHook(MetricHook):
def __init__(self, trainer: Trainer, priority: int = 10):
def __init__(self, trainer: Trainer, priority: int = 0):
super().__init__(trainer, priority)
if self._is_stage_to_log:
if self._is_stage_to_compute:
self.metric = Accuracy2p5D(epoch_only=True)
# register the metric
@ -113,11 +117,11 @@ class Accuracy2p5DHook(MetricHook):
self.metric.__class__.__name__] = self.metric
def before_test(self):
if self._is_stage_to_log:
if self._is_stage_to_compute:
self.metric.reset()
def after_test_iter(self, logits, label, *args):
if self._is_stage_to_log:
if self._is_stage_to_compute:
self.metric.update(logits, label)
@ -138,7 +142,7 @@ class Accuracy3DHook(MetricHook):
priority: int = 10):
super().__init__(trainer, priority)
if self._is_stage_to_log:
if self._is_stage_to_compute:
self.metric = Accuracy3D(epoch_only=True,
input_parallel_mode=input_parallel_mode,
weight_parallel_mode=weight_parallel_mode)
@ -148,11 +152,11 @@ class Accuracy3DHook(MetricHook):
self.metric.__class__.__name__] = self.metric
def before_test(self):
if self._is_stage_to_log:
if self._is_stage_to_compute:
self.metric.reset()
def after_test_iter(self, logits, label, *args):
if self._is_stage_to_log:
if self._is_stage_to_compute:
self.metric.update(logits, label)
@ -166,10 +170,10 @@ class AccuracyHook(MetricHook):
:type priority: int
"""
def __init__(self, trainer: Trainer, priority: int = 10):
def __init__(self, trainer: Trainer, priority: int = 0):
super().__init__(trainer, priority)
if self._is_stage_to_log:
if self._is_stage_to_compute:
self.metric = Accuracy(epoch_only=True)
# register the metric
@ -177,9 +181,9 @@ class AccuracyHook(MetricHook):
self.metric.__class__.__name__] = self.metric
def before_test(self):
if self._is_stage_to_log:
if self._is_stage_to_compute:
self.metric.reset()
def after_test_iter(self, logits, label, *args):
if self._is_stage_to_log:
if self._is_stage_to_compute:
self.metric.update(logits, label)

View File

@ -126,6 +126,33 @@ class Loss(Metric):
return a < b
class LearningRate(Metric):
"""A metric collector for learning rate.
:param epoch_only: Whether the metric only read for the full epoch
:type epoch_only: bool
"""
def __init__(self, epoch_only: bool, initial_lr: float = 0.):
super().__init__(epoch_only=epoch_only)
self.lr = 0.
def reset(self) -> None:
pass
def update(self, lr) -> None:
self.lr = lr
def get_last_step_value(self):
return self.lr
def get_accumulated_value(self):
return self.lr
def is_better(a, b) -> bool:
pass
class Accuracy(Metric):
"""A metric collector for accuracy. It only works for classification
tasks.

View File

@ -5,9 +5,9 @@ from typing import Tuple
import torch
from .context import Config
from .context.parallel_mode import ParallelMode
from .core import global_context as gpc
from colossalai.context import Config
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
__all__ = [
'get_checkpoint_path',

View File

@ -27,7 +27,7 @@ def sync_model_param_in_dp(model):
:param model: A pyTorch nn.model on whose parameters you check the consistency
'''
if gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 2:
if gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1:
for param in model.parameters():
ranks = gpc.get_ranks_in_group(ParallelMode.DATA)
dist.broadcast(param, src=ranks[0], group=gpc.get_group(ParallelMode.DATA))

View File

@ -4,6 +4,7 @@ import os
IMG_SIZE = 224
BATCH_SIZE = 256
NUM_EPOCHS = 100
model = dict(
type='VanillaResNet',
@ -67,8 +68,6 @@ loss = dict(
type='CrossEntropyLoss'
)
max_epochs = 100
from colossalai.engine import AMP_TYPE
fp16 = dict(

View File

@ -1,21 +1,20 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
NUM_EPOCH = int
model = dict()
train_data = dict()
test_data = dict()
optimizer = dict()
loss = dict()
lr_scheduler = dict()
fp16 = dict()
zero = dict()
gradient_handler = []
parallel = dict()
num_epochs = int
num_steps = int
hooks = []
cudnn_benchmark = True
cudnn_deterministic = False

View File

@ -8,10 +8,11 @@ BATCH_SIZE = 512
IMG_SIZE = 32
PATCH_SIZE = 4
DIM = 512
NUM_ATTENTION_HEADS = 8
NUM_ATTENTION_HEADS = 2
SUMMA_DIM = 2
NUM_CLASSES = 10
DEPTH = 6
DEPTH = 1
NUM_EPOCHS = 60
train_data = dict(
dataset=dict(
@ -127,14 +128,22 @@ hooks = [
dict(type='LogMetricByEpochHook'),
dict(type='Accuracy2DHook'),
dict(type='LossHook'),
dict(type='TensorboardHook', log_dir='./tfb_logs'),
dict(
type='LRSchedulerHook',
by_epoch=True,
lr_scheduler_cfg=dict(
type='LinearWarmupLR',
warmup_steps=5
)
),
dict(type='TensorboardHook', log_dir='./tb_logs'),
# dict(type='SaveCheckpointHook', interval=5, checkpoint_dir='./ckpt'),
# dict(type='LoadCheckpointHook', epoch=20, checkpoint_dir='./ckpt')
]
parallel = dict(
pipeline=dict(size=1),
tensor=dict(size=4, mode='2d'),
tensor=dict(size=1, mode='2d'),
)
# for fp16 training
@ -144,17 +153,11 @@ parallel = dict(
# initial_scale=2 ** 8
# )
lr_scheduler = dict(
type='LinearWarmupLR',
warmup_epochs=5
)
# only needed when pipeline parallel is used
# schedule = dict(
# num_microbatches=8
# )
num_epochs = 60
logging = dict(
root_path='./logs'

View File

@ -14,6 +14,7 @@ except:
BATCH_SIZE = 512
IMG_SIZE = 32
NUM_EPOCHS = 60
train_data = dict(
dataset=dict(
@ -83,6 +84,14 @@ hooks = [
),
dict(type='LossHook'),
dict(type='TensorboardHook', log_dir='./tfb_logs'),
dict(
type='LRSchedulerHook',
by_epoch=True,
lr_scheduler_cfg=dict(
type='LinearWarmupLR',
warmup_steps=5
)
),
# dict(type='SaveCheckpointHook', interval=5, checkpoint_dir='./ckpt'),
# dict(type='LoadCheckpointHook', epoch=20, checkpoint_dir='./ckpt')
]
@ -97,13 +106,6 @@ fp16 = dict(
initial_scale=2 ** 8
)
lr_scheduler = dict(
type='LinearWarmupLR',
warmup_epochs=5
)
num_epochs = 60
logging = dict(
root_path='./logs'
)

View File

@ -0,0 +1,5 @@
colossalai.engine.amp.amp\_type
===============================
.. automodule:: colossalai.engine.amp.amp_type
:members:

View File

@ -0,0 +1,5 @@
colossalai.engine.amp.grad\_scaler
==================================
.. automodule:: colossalai.engine.amp.grad_scaler
:members:

View File

@ -0,0 +1,12 @@
colossalai.engine.amp
=====================
.. automodule:: colossalai.engine.amp
:members:
.. toctree::
:maxdepth: 2
colossalai.engine.amp.amp_type
colossalai.engine.amp.grad_scaler

View File

@ -1,5 +0,0 @@
colossalai.engine.amp\_type
===========================
.. automodule:: colossalai.engine.amp_type
:members:

View File

@ -7,11 +7,6 @@ colossalai.engine
.. toctree::
:maxdepth: 2
colossalai.engine.amp
colossalai.engine.gradient_handler
colossalai.engine.schedule
.. toctree::
:maxdepth: 2
colossalai.engine.amp_type

View File

@ -21,7 +21,6 @@ colossalai
.. toctree::
:maxdepth: 2
colossalai.checkpointing
colossalai.constants
colossalai.core
colossalai.initialize

View File

@ -0,0 +1,5 @@
colossalai.utils.checkpointing
==============================
.. automodule:: colossalai.utils.checkpointing
:members:

View File

@ -9,6 +9,7 @@ colossalai.utils
:maxdepth: 2
colossalai.utils.activation_checkpoint
colossalai.utils.checkpointing
colossalai.utils.common
colossalai.utils.cuda
colossalai.utils.memory

View File

@ -17,38 +17,40 @@ parallel = dict(
)
```
The name of the dictionary variable should be **parallel**. All the arguments even **parallel** itself are optional and data,
pipeline, tensor parallel size will be set to defaulted value 1. The value of data, pipeline and tensor can be a int
representing the size of specific parallel dimension or a dictionary with a key called "size". The key "mode"
The name of the dictionary variable should be **parallel**. All the arguments even **parallel** itself are optional and
data, pipeline, tensor parallel size will be set to defaulted value 1. The value of data, pipeline and tensor can be a
int representing the size of specific parallel dimension or a dictionary with a key called "size". The key "mode"
represents the way of tensor parallelism.
## Data Parallel
Data parallel is the most common way to distribute your training task by splitting data into several shards and train
on a single shard on each device. The configuration for data parallel is detected automatically and set for you. You do
not have to explicitly set them in your configurations. When data parallel size is larger than 1, Colossal-AI automatically
Data parallel is the most common way to distribute your training task by splitting data into several shards and train on
a single shard on each device. The configuration for data parallel is detected automatically and set for you. You do not
have to explicitly set them in your configurations. When data parallel size is larger than 1, Colossal-AI automatically
adds the distributed data sampler to the dataloader to shard the dataset.
## 1D, 2D, 2.5D and 3D Parallel
To enable hybrid parallelism, we provide an array of tensor parallelism. We provide the list of papers which match each
To enable hybrid parallelism, we provide an array of tensor parallelism. We provide the list of papers which match each
tensor parallel method. These parallel modes need to work with the distributed layers provided by Colossal-AI.
- 1D: [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053)
-
1D: [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053)
- 2D: [An Efficient 2D Method for Training Super-Large Deep Learning Models](https://arxiv.org/abs/2104.05343)
2D parallel relies on the SUMMA matrix multiplication algorithm and splits the input data,
model weights and layer outputs along two different dimensions. The tensor chunks are distributed over a 2D mesh of $P = N^2$
devices where $N$ is the number of tensor chunks in a single dimension.
2D parallel relies on the SUMMA matrix multiplication algorithm and splits the input data, model weights and layer
outputs along two different dimensions. The tensor chunks are distributed over a 2D mesh of $P = N^2$ devices where
$N$ is the number of tensor chunks in a single dimension.
- 2.5D: [2.5-dimensional distributed model training](https://arxiv.org/abs/2105.14500)
Inspired by the 2.5D matrix multiplication algorithm, 2.5D parallel introduces a novel tensor parallelism which further
parallelizes 2D tensor parallelism. An amount of $P = N^2 d$ processors are arranged into $d$ layers,
where each layer performs matrix multiplication operations independently with a dimension $N$.
Inspired by the 2.5D matrix multiplication algorithm, 2.5D parallel introduces a novel tensor parallelism which
further parallelizes 2D tensor parallelism. An amount of $P = N^2 d$ processors are arranged into $d$ layers, where
each layer performs matrix multiplication operations independently with a dimension $N$.
- 3D: [Maximizing Parallelism in Distributed Training for Huge Neural Networks](https://arxiv.org/abs/2105.14450)
We also introduce a 3D tensor parallelism that parallelizes neural networks on a 3D processor cube. This method achieves
the optimal, $O(P^{1/3})$ communication overhead on $P$ processors, while both computation and memory usage are evenly distributed
through optimized load balancing of parameters as well as activations.
We also introduce a 3D tensor parallelism that parallelizes neural networks on a 3D processor cube. This method
achieves the optimal, $O(P^{1/3})$ communication overhead on $P$ processors, while both computation and memory usage
are evenly distributed through optimized load balancing of parameters as well as activations.
```python
# 1D parallel
@ -78,12 +80,12 @@ parallel = dict(
## Pipeline Parallel (experimental)
Pipeline parallelism is to split the model into several partitions by layer. For example, let's assume we have a simple
model which consists of two linear layer. We have two GPUs, and we can allocate the first linear layer to the first GPU
Pipeline parallelism is to split the model into several partitions by layer. For example, let's assume we have a simple
model which consists of two linear layer. We have two GPUs, and we can allocate the first linear layer to the first GPU
and the second layer to the second GPU. This example of course wastes the computing resources and is only to demonstrate
the idea of pipeline parallelism.
the idea of pipeline parallelism.
As PyTorch is based on dynamic computation graph, the computation flow is not known until execution. To support pipeline
As PyTorch is based on dynamic computation graph, the computation flow is not known until execution. To support pipeline
parallelism in PyTorch, you may need to add one more attribute, `layers_cfg` in your model class which tells Colossal-AI
the sequence of execution. One example you can refer is `colossalai.nn.model.VanillaResNet`.
@ -192,9 +194,9 @@ class VanillaResNet(BaseModel):
]
```
You can set the number of pipeline stages in your configuration file. When pipeline size is larger than 1, Colossal-AI
will automatically creates the pipeline schedule which defines the forward and backward step. You can specify how many microbatches
to run in each step in the `schedule` configuration.
You can set the number of pipeline stages in your configuration file. When pipeline size is larger than 1, Colossal-AI
will automatically creates the pipeline schedule which defines the forward and backward step. You can specify how many
microbatches to run in each step in the `schedule` configuration.
```python
parallel = dict(
@ -206,10 +208,11 @@ schedule = dict(
num_microbatches = 4 # set the number of microbatches per step
)
```
This feature is still in development and is only experimental for now.
## Sequence Parallel (experimental)
Sequence parallel is to support long-sequence modelling such as document-level text understanding and medical imaging.
This method is proposed in [Sequence Parallelism: Making 4D Parallelism Possible](https://arxiv.org/abs/2105.13120).
Sequence parallel is to support long-sequence modelling such as document-level text understanding and medical imaging.
This method is proposed in [Sequence Parallelism: Making 4D Parallelism Possible](https://arxiv.org/abs/2105.13120).
This feature is still in development and is only experimental for now.

View File

@ -1,8 +1,8 @@
# Quick demo
Colossal-AI is an integrated large-scale deep learning system with efficient parallelization techniques. The system
can accelerate model training on distributed systems with multiple GPUs by applying parallelization techniques. The
system can also run on systems with only one GPU. Quick demos showing how to use Colossal-AI are given below.
Colossal-AI is an integrated large-scale deep learning system with efficient parallelization techniques. The system can
accelerate model training on distributed systems with multiple GPUs by applying parallelization techniques. The system
can also run on systems with only one GPU. Quick demos showing how to use Colossal-AI are given below.
## Single GPU
@ -32,25 +32,17 @@ realizes the training process.
```python
import colossalai
from colossalai.core import global_context as gpc
from colossalai.engine import Engine
from colossalai.logging import get_global_dist_logger
from colossalai.trainer import Trainer
def run_trainer():
model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize()
engine, train_dataloader, test_dataloader = colossalai.initialize()
logger = get_global_dist_logger()
schedule.data_sync = False
engine = Engine(
model=model,
criterion=criterion,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
schedule=schedule
)
logger.info("engine is built", ranks=[0])
trainer = Trainer(engine=engine,
hooks_cfg=gpc.config.hooks,
verbose=True)
logger.info("trainer is built", ranks=[0])
@ -58,11 +50,13 @@ def run_trainer():
trainer.fit(
train_dataloader=train_dataloader,
test_dataloader=test_dataloader,
max_epochs=gpc.config.num_epochs,
epochs=gpc.config.num_epochs,
hooks_cfg=gpc.config.hooks,
display_progress=True,
test_interval=2
)
if __name__ == '__main__':
run_trainer()
```
@ -72,9 +66,9 @@ Zoo. The detailed substitution process is elaborated [here](model.md).
## Features
Colossal-AI provides a collection of parallel training components for you. We aim to support you with your development of
distributed deep learning models just like how you write single-GPU deep learning models. We provide friendly tools to
kickstart distributed training in a few lines.
Colossal-AI provides a collection of parallel training components for you. We aim to support you with your development
of distributed deep learning models just like how you write single-GPU deep learning models. We provide friendly tools
to kickstart distributed training in a few lines.
- [Data Parallelism](parallelization.md)
- [Pipeline Parallelism](parallelization.md)

View File

@ -4,40 +4,36 @@ Colossal-AI是一个大规模深度学习系统其中包含高效的并行技
## 单GPU系统
在带有GPU的非分布式系统上进行模型训练时Colossal-AI可以达到当前的基线效率。[这里](https://colab.research.google.com/drive/1fJnqqFzPuzZ_kn1lwCpG2nh3l2ths0KE?usp=sharing#scrollTo=cQ_y7lBG09LS)我们给出一个Google Colab示例展现如何使用Colossal-AI与CIFAR10数据集在非分布式系统上训练一个LeNet模型。
在带有GPU的非分布式系统上进行模型训练时Colossal-AI可以达到当前的基线效率。[这里](https://colab.research.google.com/drive/1fJnqqFzPuzZ_kn1lwCpG2nh3l2ths0KE?usp=sharing#scrollTo=cQ_y7lBG09LS)我们给出一个Google
Colab示例展现如何使用Colossal-AI与CIFAR10数据集在非分布式系统上训练一个LeNet模型。
## 多GPU系统
在多GPU的分布式系统上训练深度学习模型时Colossal-AI可以使用高效的并行技术来显著地加速训练过程这些技术将在下面的[并行技术](parallelization.md)章节中被详述。下面的代码将在拥有四个GPU的分布式系统上训练一个ViT模型其中`HOST`变量为您分布式系统的IP地址。请注意下面的代码使用了[Slurm](https://slurm.schedmd.com/documentation.html)作业调度系统。
在多GPU的分布式系统上训练深度学习模型时Colossal-AI可以使用高效的并行技术来显著地加速训练过程这些技术将在下面的[并行技术](parallelization.md)
章节中被详述。下面的代码将在拥有四个GPU的分布式系统上训练一个ViT模型其中`HOST`
变量为您分布式系统的IP地址。请注意下面的代码使用了[Slurm](https://slurm.schedmd.com/documentation.html)作业调度系统。
```bash
HOST=xxx.xxx.xxx.xxx srun ./scripts/slurm_dist_train.sh ./examples/run_trainer.py ./configs/vit/vit_2d.py
```
`./configs/vit/vit_2d.py`是一个[配置文件](config.md)Colossal-AI使用配置文件来定义训练过程中需要用到的参数比如模型类型、数据集、以及优化器、学习率调度器等。您可以通过编写配置文件的方式来训练不同的模型。`./examples/run_trainer.py`是一个标准的训练脚本,具体代码已经附在下面。该脚本可以读入配置文件中的训练参数并训练模型。
`./configs/vit/vit_2d.py`是一个[配置文件](config.md)
Colossal-AI使用配置文件来定义训练过程中需要用到的参数比如模型类型、数据集、以及优化器、学习率调度器等。您可以通过编写配置文件的方式来训练不同的模型。`./examples/run_trainer.py`
是一个标准的训练脚本,具体代码已经附在下面。该脚本可以读入配置文件中的训练参数并训练模型。
```python
import colossalai
from colossalai.core import global_context as gpc
from colossalai.engine import Engine
from colossalai.logging import get_global_dist_logger
from colossalai.trainer import Trainer
def run_trainer():
model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize()
engine, train_dataloader, test_dataloader = colossalai.initialize()
logger = get_global_dist_logger()
schedule.data_sync = False
engine = Engine(
model=model,
criterion=criterion,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
schedule=schedule
)
logger.info("engine is built", ranks=[0])
trainer = Trainer(engine=engine,
hooks_cfg=gpc.config.hooks,
verbose=True)
logger.info("trainer is built", ranks=[0])
@ -45,11 +41,13 @@ def run_trainer():
trainer.fit(
train_dataloader=train_dataloader,
test_dataloader=test_dataloader,
max_epochs=gpc.config.num_epochs,
epochs=gpc.config.num_epochs,
hooks_cfg=gpc.config.hooks,
display_progress=True,
test_interval=2
)
if __name__ == '__main__':
run_trainer()
```

View File

@ -2,9 +2,9 @@
## Build your engine
To better understand how `Engine` class works, let's start from the conception of the process function in common engines. The process function
usually controls the behavior over a batch of a dataset, `Engine` class just controls the process function. Here we give a standard process
function in the following code block.
To better understand how `Engine` class works, let's start from the conception of the process function in common
engines. The process function usually controls the behavior over a batch of a dataset, `Engine` class just controls the
process function. Here we give a standard process function in the following code block.
```python
def process_function(dataloader, model, criterion, optim):
@ -16,32 +16,33 @@ def process_function(dataloader, model, criterion, optim):
optim.setp()
```
In `ignite.engine` or `keras.engine`, the process function is always provided by users. However, it is tricky for users to write their own process
functions for pipeline parallelism. Aiming at offering accessible hybrid parallelism for users, we provide the powerful `Engine` class. This class
enables pipeline parallelism and offers one-forward-one-backward non-interleaving strategy. Also, you can use pre-defined learning rate scheduler
in the `Engine` class to adjust learning rate during training.
In `ignite.engine` or `keras.engine`, the process function is always provided by users. However, it is tricky for users
to write their own process functions for pipeline parallelism. Aiming at offering accessible hybrid parallelism for
users, we provide the powerful `Engine` class. This class enables pipeline parallelism and offers
one-forward-one-backward non-interleaving strategy. Also, you can use pre-defined learning rate scheduler in
the `Engine` class to adjust learning rate during training.
In order to build your engine, just set variables `model`, `criterion`, `optimizer`, `lr_scheduler` and `schedule`. The following code block provides
an example.
In order to build your engine, just set variables `model`, `criterion`, `optimizer`, `lr_scheduler` and `schedule`. The
following code block provides an example. **The engine is automatically created from the config file for you if you
start with `colossalai.initialize`.**
```python
import torch
import torch.nn as nn
import torchvision.models as models
import colossalai
from colossalai.engine import Engine
model = models.resnet18()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model)
lr_scheduler = colossalai.nn.lr_scheduler.CosineAnnealingLR(optimizer, 1000)
schedule = colossalai.engine.schedule.NoPipelineSchedule()
optimizer = torch.optim.Adam(model.parameters())
schedule = colossalai.engine.NoPipelineSchedule()
MyEngine = Engine(
model=model,
criterion=criterion,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
schedule=schedule
step_schedule=schedule
)
```
@ -51,21 +52,24 @@ More information regarding the class can be found in the API references.
### Overview
To learn how to customize a trainer which meets your needs, let's first give a look at the `Trainer` class. We highly recommend that you read *Get Started*
To learn how to customize a trainer which meets your needs, let's first give a look at the `Trainer` class. We highly
recommend that you read *Get Started*
section and *Build your engine* first.
The `Trainer` class enables researchers and engineers to use our system more conveniently. Instead of having to write your own scripts, you can simply
construct your own trainer by calling the `Trainer` class, just like what we did in the following code block.
The `Trainer` class enables researchers and engineers to use our system more conveniently. Instead of having to write
your own scripts, you can simply construct your own trainer by calling the `Trainer` class, just like what we did in the
following code block.
```python
MyTrainer = Trainer(MyEngine)
MyTrainer = Trainer(my_engine)
```
After that, you can use the `fit` method to train or evaluate your model. In order to make our `Trainer` class even more powerful, we incorporate a set of
handy tools to the class. For example, you can monitor or record the running states and metrics which indicate the current performance of the model. These
functions are realized by hooks. The `BasicHook` class allows you to execute your hook functions at specified time. We have already created some practical
hooks for you, as listed below. What you need to do is just picking the right ones which suit your needs. Detailed descriptions of the class can be found
in the API references.
After that, you can use the `fit` method to train or evaluate your model. In order to make our `Trainer` class even more
powerful, we incorporate a set of handy tools to the class. For example, you can monitor or record the running states
and metrics which indicate the current performance of the model. These functions are realized by hooks. The `BasicHook`
class allows you to execute your hook functions at specified time. We have already created some practical hooks for you,
as listed below. What you need to do is just picking the right ones which suit your needs. Detailed descriptions of the
class can be found in the API references.
```python
hooks = [
@ -80,18 +84,21 @@ hooks = [
]
```
These hook functions will record metrics, elapsed time and memory usage and write them to log after each epoch. Besides, they print the current loss and
accuracy to let users monitor the performance of the model.
These hook functions will record metrics, elapsed time and memory usage and write them to log after each epoch. Besides,
they print the current loss and accuracy to let users monitor the performance of the model.
### Hook
If you have your specific needs, feel free to extend our `BaseHook` class to add your own functions, or our `MetricHook` class to write a metric collector.
These hook functions can be called at twelve timing in the trainer's life cycle. Besides, you can define the priorities of all hooks to arrange the execution order of them.
More information can be found in the API references.
If you have your specific needs, feel free to extend our `BaseHook` class to add your own functions, or our `MetricHook`
class to write a metric collector. These hook functions can be called at twelve timing in the trainer's life cycle.
Besides, you can define the priorities of all hooks to arrange the execution order of them. More information can be
found in the API references.
### Metric
You can write your own metrics by extending our `Metric` class. It should be used with the `MetricHook` class. When your write your own metric hooks, please set
the priority carefully and make sure the hook is called before other hooks which might require the results of the metric hook.
You can write your own metrics by extending our `Metric` class. It should be used with the `MetricHook` class. When your
write your own metric hooks, please set the priority carefully and make sure the hook is called before other hooks which
might require the results of the metric hook.
We've already provided some metric hooks and we store metric objects in `runner.states['metrics']`. It is a dictionary and metrics can be accessed by their names.
We've already provided some metric hooks and we store metric objects in `runner.states['metrics']`. It is a dictionary
and metrics can be accessed by their names.

View File

@ -14,28 +14,30 @@ def process_function(dataloader, model, criterion, optim):
optim.setp()
```
在`ignite.engine`与`keras.engine`中,进程函数需要由用户提供,然而,用户很难为流水线并行编写进程函数。为了向用户提供方便的混合并行,我们提供了具备强大功能的`Engine`类,该类支持流水线并行,并提供前向传播后向传播不交织的策略。同时,您可以在`Engine`类中使用您事先定义好的学习率调度器来在训练过程中调整学习率。
在`ignite.engine`与`keras.engine`中,进程函数需要由用户提供,然而,用户很难为流水线并行编写进程函数。为了向用户提供方便的混合并行,我们提供了具备强大功能的`Engine`
类,该类支持流水线并行,并提供前向传播后向传播不交织的策略。同时,您可以在`Engine`类中使用您事先定义好的学习率调度器来在训练过程中调整学习率。
您在构造引擎时只需要定义`model`、`criterion`、`optimizer`、`lr_scheduler`与`schedule`等变量即可,下面的代码块给出了一个这样的例子。
**如果你使用`colossalai.initialize`的话engine会从config文件里自动构建。**
```python
import torch
import torch.nn as nn
import torchvision.models as models
import colossalai
from colossalai.engine import Engine
model = models.resnet18()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model)
lr_scheduler = colossalai.nn.lr_scheduler.CosineAnnealingLR(optimizer, 1000)
schedule = colossalai.engine.schedule.NoPipelineSchedule()
schedule = colossalai.engine.NoPipelineSchedule()
MyEngine = Engine(
model=model,
criterion=criterion,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
schedule=schedule
step_schedule=schedule
)
```
@ -48,10 +50,12 @@ MyEngine = Engine(
`Trainer`类旨在让科研工作者和工程师更加方便地使用我们的系统,您不需要自己写脚本,只需要调用`Trainer`类来构造您的训练器即可,就像下面的代码块中所做的。
```python
MyTrainer = Trainer(MyEngine)
MyTrainer = Trainer(my_trainer)
```
在此之后,您可以使用`fit`方法来训练或调用您的模型。除此之外,为了让我们的`Trainer`类拥有更强大的功能,我们加入了一系列方便您使用的工具。例如,您可以在训练过程中持续监测并记录模型目前的运行状态和表现,这些功能都是通过钩子函数来实现的。我们提供的`BasicHook`类让您可以在指定时间执行您的钩子函数。如下方的代码块所示我们事先为您定义好了一些实用的钩子函数您需要做的就是找到符合您需求的钩子函数。更多该类的相关信息可以在API信息中找到。
在此之后,您可以使用`fit`方法来训练或调用您的模型。除此之外,为了让我们的`Trainer`
类拥有更强大的功能,我们加入了一系列方便您使用的工具。例如,您可以在训练过程中持续监测并记录模型目前的运行状态和表现,这些功能都是通过钩子函数来实现的。我们提供的`BasicHook`
类让您可以在指定时间执行您的钩子函数。如下方的代码块所示我们事先为您定义好了一些实用的钩子函数您需要做的就是找到符合您需求的钩子函数。更多该类的相关信息可以在API信息中找到。
```python
hooks = [
@ -70,7 +74,8 @@ hooks = [
### 钩子函数
如果您有个性化需求,您可以继承我们的`BaseHook`类并添加您的钩子函数,或者继承我们的`MetricHook`来编写您需要的度量标准。这些钩子函数可以在`Trainer`生命周期的12个时间点被执行。更多该类的相关信息可以在API信息中找到。
如果您有个性化需求,您可以继承我们的`BaseHook`类并添加您的钩子函数,或者继承我们的`MetricHook`来编写您需要的度量标准。这些钩子函数可以在`Trainer`
生命周期的12个时间点被执行。更多该类的相关信息可以在API信息中找到。
### 度量标准

View File

@ -1,370 +1,370 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "colossal_cifar_demo.ipynb",
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "colossal_cifar_demo.ipynb",
"provenance": []
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "uhrbvVEh2iJd"
},
"source": [
"# Train an image classifier\n"
]
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "uhrbvVEh2iJd"
},
"source": [
"# Train an image classifier\n"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "vP7LvCpG23a2",
"outputId": "b37f7203-8a02-4736-c527-603f2bb34d7d"
},
"source": [
"!pip install ColossalAI deepspeed"
],
"execution_count": null,
"outputs": [
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "vP7LvCpG23a2",
"outputId": "b37f7203-8a02-4736-c527-603f2bb34d7d"
},
"source": [
"!pip install ColossalAI deepspeed"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Requirement already satisfied: ColossalAI in /usr/local/lib/python3.7/dist-packages (0.1)\n",
"Requirement already satisfied: deepspeed in /usr/local/lib/python3.7/dist-packages (0.5.4)\n",
"Requirement already satisfied: packaging in /usr/local/lib/python3.7/dist-packages (from deepspeed) (21.0)\n",
"Requirement already satisfied: triton in /usr/local/lib/python3.7/dist-packages (from deepspeed) (1.1.1)\n",
"Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (from deepspeed) (4.62.3)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from deepspeed) (1.19.5)\n",
"Requirement already satisfied: tensorboardX==1.8 in /usr/local/lib/python3.7/dist-packages (from deepspeed) (1.8)\n",
"Requirement already satisfied: ninja in /usr/local/lib/python3.7/dist-packages (from deepspeed) (1.10.2.2)\n",
"Requirement already satisfied: torch in /usr/local/lib/python3.7/dist-packages (from deepspeed) (1.9.0+cu111)\n",
"Requirement already satisfied: psutil in /usr/local/lib/python3.7/dist-packages (from deepspeed) (5.4.8)\n",
"Requirement already satisfied: protobuf>=3.2.0 in /usr/local/lib/python3.7/dist-packages (from tensorboardX==1.8->deepspeed) (3.17.3)\n",
"Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from tensorboardX==1.8->deepspeed) (1.15.0)\n",
"Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging->deepspeed) (2.4.7)\n",
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch->deepspeed) (3.7.4.3)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from triton->deepspeed) (3.3.0)\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "UVKEurtS4SFS",
"outputId": "99fb6050-5da7-4f27-b4eb-9b3ccf830efb"
},
"source": [
"import colossalai\n",
"from colossalai.engine import Engine, NoPipelineSchedule\n",
"from colossalai.trainer import Trainer\n",
"from colossalai.context import Config\n",
"import torch"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Please install apex to use FP16 Optimizer\n",
"Apex should be installed to use the FP16 optimizer\n",
"apex is required for mixed precision training\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PpFfhNBD7NSn"
},
"source": [
"First, we should initialize distributed environment. Though we just use single GPU in this example, we still need initialize distributed environment for compatibility. We just consider the simplest case here, so we just set the number of parallel processes to 1."
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "8yF7Lc-K7NAS",
"outputId": "01312349-a8b0-4de4-9103-7d1b48e6cc36"
},
"source": [
"parallel_cfg = Config(dict(parallel=dict(\n",
" data=dict(size=1),\n",
" pipeline=dict(size=1),\n",
" tensor=dict(size=1, mode=None),\n",
")))\n",
"colossalai.init_dist(config=parallel_cfg,\n",
" local_rank=0,\n",
" world_size=1,\n",
" host='127.0.0.1',\n",
" port=8888,\n",
" backend='nccl')"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"colossalai - torch.distributed.distributed_c10d - 2021-10-15 03:27:51,596 INFO: Added key: store_based_barrier_key:1 to store for rank: 0\n",
"colossalai - torch.distributed.distributed_c10d - 2021-10-15 03:27:51,598 INFO: Rank 0: Completed store-based barrier for 1 nodes.\n",
"colossalai - torch.distributed.distributed_c10d - 2021-10-15 03:27:51,602 INFO: Added key: store_based_barrier_key:2 to store for rank: 0\n",
"colossalai - torch.distributed.distributed_c10d - 2021-10-15 03:27:51,605 INFO: Rank 0: Completed store-based barrier for 1 nodes.\n",
"colossalai - torch.distributed.distributed_c10d - 2021-10-15 03:27:51,608 INFO: Added key: store_based_barrier_key:3 to store for rank: 0\n",
"colossalai - torch.distributed.distributed_c10d - 2021-10-15 03:27:51,610 INFO: Rank 0: Completed store-based barrier for 1 nodes.\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"process rank 0 is bound to device 0\n",
"initialized seed on rank 0, numpy: 1024, python random: 1024, ParallelMode.DATA: 1024, ParallelMode.TENSOR: 1124,the default parallel seed is ParallelMode.DATA.\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ppjmMxc_81TK"
},
"source": [
"Load and normalize the CIFAR10 training and test datasets using `colossalai.nn.data`. Note that we have wrapped `torchvision.transforms`, so that we can simply use the config dict to use them."
]
},
{
"cell_type": "code",
"metadata": {
"id": "ZyGhyD47-dUY",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "98bbf2d1-a1c4-4bb4-b6df-600777b1e8f5"
},
"source": [
"transform_cfg = [\n",
" dict(type='ToTensor'),\n",
" dict(type='Normalize',\n",
" mean=[0.4914, 0.4822, 0.4465],\n",
" std=[0.2023, 0.1994, 0.2010]),\n",
"]\n",
"\n",
"batch_size = 128\n",
"\n",
"trainset = colossalai.nn.data.CIFAR10Dataset(transform_cfg, root='./data', train=True)\n",
"trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)\n",
"\n",
"testset = colossalai.nn.data.CIFAR10Dataset(transform_cfg, root='./data', train=False)\n",
"testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Files already downloaded and verified\n",
"Files already downloaded and verified\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NvPbfLLR9NzC"
},
"source": [
"We just define a simple Convolutional Neural Network here."
]
},
{
"cell_type": "code",
"metadata": {
"id": "cQ_y7lBG09LS"
},
"source": [
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"\n",
"class Net(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.conv1 = nn.Conv2d(3, 6, 5)\n",
" self.pool = nn.MaxPool2d(2, 2)\n",
" self.conv2 = nn.Conv2d(6, 16, 5)\n",
" self.fc1 = nn.Linear(16 * 5 * 5, 120)\n",
" self.fc2 = nn.Linear(120, 84)\n",
" self.fc3 = nn.Linear(84, 10)\n",
"\n",
" def forward(self, x):\n",
" x = self.pool(F.relu(self.conv1(x)))\n",
" x = self.pool(F.relu(self.conv2(x)))\n",
" x = torch.flatten(x, 1) # flatten all dimensions except batch\n",
" x = F.relu(self.fc1(x))\n",
" x = F.relu(self.fc2(x))\n",
" x = self.fc3(x)\n",
" return x\n",
"\n",
"\n",
"model = Net().cuda()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "tgsszAmM9dYZ"
},
"source": [
"Define a Loss function and optimizer. And then we use them to initialize `Engine` and `Trainer`. We provide various training / evaluating hooks. In this case, we just use the simplest hooks which can compute and print loss and accuracy."
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "YtaDoCax1BCf",
"outputId": "b33b1641-03d8-4597-c8c2-1a4c1d61e9b0"
},
"source": [
"import torch.optim as optim\n",
"\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)\n",
"schedule = NoPipelineSchedule()\n",
"engine = Engine(\n",
" model=model,\n",
" criterion=criterion,\n",
" optimizer=optimizer,\n",
" lr_scheduler=None,\n",
" schedule=schedule\n",
" )\n",
"trainer = Trainer(engine=engine,\n",
" hooks_cfg=[dict(type='LossHook'), dict(type='LogMetricByEpochHook'), dict(type='AccuracyHook')],\n",
" verbose=True)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"colossalai - rank_0 - 2021-10-15 03:27:56,018 WARNING: No gradient handler is set up, please make sure you do not need to all-reduce the gradients after a training step.\n",
"colossalai - rank_0 - 2021-10-15 03:27:56,024 INFO: build LogMetricByEpochHook for train, priority = 1\n",
"colossalai - rank_0 - 2021-10-15 03:27:56,026 INFO: build LossHook for train, priority = 10\n",
"colossalai - rank_0 - 2021-10-15 03:27:56,029 INFO: build AccuracyHook for train, priority = 10\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_JR2TuvH99Ik"
},
"source": [
"Then we set training configs. We train our model for 10 epochs and it will be evaluated every 1 epoch. Set `display_progress` to `True` to display the training / evaluating progress bar."
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "w-J3IP-J1sfx",
"outputId": "bdb76939-04f1-4124-ce5e-3af44c0d902c"
},
"source": [
"num_epochs = 10\n",
"test_interval = 1\n",
"trainer.fit(\n",
" train_dataloader=trainloader,\n",
" test_dataloader=testloader,\n",
" max_epochs=num_epochs,\n",
" display_progress=True,\n",
" test_interval=test_interval\n",
" )"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"[Epoch 0 train]: 0%| | 0/391 [00:00<?, ?it/s]/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /pytorch/c10/core/TensorImpl.h:1156.)\n",
" return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)\n",
"[Epoch 0 train]: 100%|██████████| 391/391 [00:14<00:00, 26.82it/s]\n",
"colossalai - rank_0 - 2021-10-15 03:28:11,088 INFO: Training - Epoch 1 - LogMetricByEpochHook: Loss = 2.29158\n",
"[Epoch 0 val]: 100%|██████████| 79/79 [00:02<00:00, 28.66it/s]\n",
"colossalai - rank_0 - 2021-10-15 03:28:14,040 INFO: Testing - Epoch 1 - LogMetricByEpochHook: Loss = 2.26517, Accuracy = 0.14820\n",
"[Epoch 1 train]: 100%|██████████| 391/391 [00:14<00:00, 26.31it/s]\n",
"colossalai - rank_0 - 2021-10-15 03:28:29,059 INFO: Training - Epoch 2 - LogMetricByEpochHook: Loss = 2.15763\n",
"[Epoch 1 val]: 100%|██████████| 79/79 [00:02<00:00, 28.50it/s]\n",
"colossalai - rank_0 - 2021-10-15 03:28:32,007 INFO: Testing - Epoch 2 - LogMetricByEpochHook: Loss = 2.00450, Accuracy = 0.27850\n",
"[Epoch 2 train]: 100%|██████████| 391/391 [00:14<00:00, 26.08it/s]\n",
"colossalai - rank_0 - 2021-10-15 03:28:47,167 INFO: Training - Epoch 3 - LogMetricByEpochHook: Loss = 1.85409\n",
"[Epoch 2 val]: 100%|██████████| 79/79 [00:02<00:00, 27.89it/s]\n",
"colossalai - rank_0 - 2021-10-15 03:28:50,168 INFO: Testing - Epoch 3 - LogMetricByEpochHook: Loss = 1.73788, Accuracy = 0.35990\n",
"[Epoch 3 train]: 100%|██████████| 391/391 [00:14<00:00, 26.09it/s]\n",
"colossalai - rank_0 - 2021-10-15 03:29:05,330 INFO: Training - Epoch 4 - LogMetricByEpochHook: Loss = 1.69363\n",
"[Epoch 3 val]: 100%|██████████| 79/79 [00:02<00:00, 28.43it/s]\n",
"colossalai - rank_0 - 2021-10-15 03:29:08,290 INFO: Testing - Epoch 4 - LogMetricByEpochHook: Loss = 1.65005, Accuracy = 0.39350\n",
"[Epoch 4 train]: 100%|██████████| 391/391 [00:15<00:00, 25.97it/s]\n",
"colossalai - rank_0 - 2021-10-15 03:29:23,530 INFO: Training - Epoch 5 - LogMetricByEpochHook: Loss = 1.61387\n",
"[Epoch 4 val]: 100%|██████████| 79/79 [00:02<00:00, 27.75it/s]\n",
"colossalai - rank_0 - 2021-10-15 03:29:26,515 INFO: Testing - Epoch 5 - LogMetricByEpochHook: Loss = 1.57507, Accuracy = 0.42430\n",
"[Epoch 5 train]: 100%|██████████| 391/391 [00:15<00:00, 25.92it/s]\n",
"colossalai - rank_0 - 2021-10-15 03:29:41,764 INFO: Training - Epoch 6 - LogMetricByEpochHook: Loss = 1.55712\n",
"[Epoch 5 val]: 100%|██████████| 79/79 [00:02<00:00, 27.51it/s]\n",
"colossalai - rank_0 - 2021-10-15 03:29:44,778 INFO: Testing - Epoch 6 - LogMetricByEpochHook: Loss = 1.53242, Accuracy = 0.43700\n",
"[Epoch 6 train]: 100%|██████████| 391/391 [00:14<00:00, 26.13it/s]\n",
"colossalai - rank_0 - 2021-10-15 03:29:59,927 INFO: Training - Epoch 7 - LogMetricByEpochHook: Loss = 1.51618\n",
"[Epoch 6 val]: 100%|██████████| 79/79 [00:02<00:00, 28.31it/s]\n",
"colossalai - rank_0 - 2021-10-15 03:30:02,884 INFO: Testing - Epoch 7 - LogMetricByEpochHook: Loss = 1.49720, Accuracy = 0.45430\n",
"[Epoch 7 train]: 100%|██████████| 391/391 [00:14<00:00, 26.23it/s]\n",
"colossalai - rank_0 - 2021-10-15 03:30:17,968 INFO: Training - Epoch 8 - LogMetricByEpochHook: Loss = 1.47857\n",
"[Epoch 7 val]: 100%|██████████| 79/79 [00:02<00:00, 27.97it/s]\n",
"colossalai - rank_0 - 2021-10-15 03:30:20,967 INFO: Testing - Epoch 8 - LogMetricByEpochHook: Loss = 1.45808, Accuracy = 0.46320\n",
"[Epoch 8 train]: 100%|██████████| 391/391 [00:14<00:00, 26.11it/s]\n",
"colossalai - rank_0 - 2021-10-15 03:30:36,129 INFO: Training - Epoch 9 - LogMetricByEpochHook: Loss = 1.44656\n",
"[Epoch 8 val]: 100%|██████████| 79/79 [00:02<00:00, 28.18it/s]\n",
"colossalai - rank_0 - 2021-10-15 03:30:39,096 INFO: Testing - Epoch 9 - LogMetricByEpochHook: Loss = 1.44903, Accuracy = 0.46580\n",
"[Epoch 9 train]: 100%|██████████| 391/391 [00:15<00:00, 25.97it/s]\n",
"colossalai - rank_0 - 2021-10-15 03:30:54,342 INFO: Training - Epoch 10 - LogMetricByEpochHook: Loss = 1.41120\n",
"[Epoch 9 val]: 100%|██████████| 79/79 [00:02<00:00, 28.05it/s]\n",
"colossalai - rank_0 - 2021-10-15 03:30:57,332 INFO: Testing - Epoch 10 - LogMetricByEpochHook: Loss = 1.41242, Accuracy = 0.48500\n"
]
}
]
"output_type": "stream",
"name": "stdout",
"text": [
"Requirement already satisfied: ColossalAI in /usr/local/lib/python3.7/dist-packages (0.1)\n",
"Requirement already satisfied: deepspeed in /usr/local/lib/python3.7/dist-packages (0.5.4)\n",
"Requirement already satisfied: packaging in /usr/local/lib/python3.7/dist-packages (from deepspeed) (21.0)\n",
"Requirement already satisfied: triton in /usr/local/lib/python3.7/dist-packages (from deepspeed) (1.1.1)\n",
"Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (from deepspeed) (4.62.3)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from deepspeed) (1.19.5)\n",
"Requirement already satisfied: tensorboardX==1.8 in /usr/local/lib/python3.7/dist-packages (from deepspeed) (1.8)\n",
"Requirement already satisfied: ninja in /usr/local/lib/python3.7/dist-packages (from deepspeed) (1.10.2.2)\n",
"Requirement already satisfied: torch in /usr/local/lib/python3.7/dist-packages (from deepspeed) (1.9.0+cu111)\n",
"Requirement already satisfied: psutil in /usr/local/lib/python3.7/dist-packages (from deepspeed) (5.4.8)\n",
"Requirement already satisfied: protobuf>=3.2.0 in /usr/local/lib/python3.7/dist-packages (from tensorboardX==1.8->deepspeed) (3.17.3)\n",
"Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from tensorboardX==1.8->deepspeed) (1.15.0)\n",
"Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging->deepspeed) (2.4.7)\n",
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch->deepspeed) (3.7.4.3)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from triton->deepspeed) (3.3.0)\n"
]
}
]
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "UVKEurtS4SFS",
"outputId": "99fb6050-5da7-4f27-b4eb-9b3ccf830efb"
},
"source": [
"import colossalai\n",
"from colossalai.engine import Engine, NoPipelineSchedule\n",
"from colossalai.trainer import Trainer\n",
"from colossalai.context import Config\n",
"import torch"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Please install apex to use FP16 Optimizer\n",
"Apex should be installed to use the FP16 optimizer\n",
"apex is required for mixed precision training\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PpFfhNBD7NSn"
},
"source": [
"First, we should initialize distributed environment. Though we just use single GPU in this example, we still need initialize distributed environment for compatibility. We just consider the simplest case here, so we just set the number of parallel processes to 1."
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "8yF7Lc-K7NAS",
"outputId": "01312349-a8b0-4de4-9103-7d1b48e6cc36"
},
"source": [
"parallel_cfg = Config(dict(parallel=dict(\n",
" data=dict(size=1),\n",
" pipeline=dict(size=1),\n",
" tensor=dict(size=1, mode=None),\n",
")))\n",
"colossalai.init_dist(config=parallel_cfg,\n",
" local_rank=0,\n",
" world_size=1,\n",
" host='127.0.0.1',\n",
" port=8888,\n",
" backend='nccl')"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"colossalai - torch.distributed.distributed_c10d - 2021-10-15 03:27:51,596 INFO: Added key: store_based_barrier_key:1 to store for rank: 0\n",
"colossalai - torch.distributed.distributed_c10d - 2021-10-15 03:27:51,598 INFO: Rank 0: Completed store-based barrier for 1 nodes.\n",
"colossalai - torch.distributed.distributed_c10d - 2021-10-15 03:27:51,602 INFO: Added key: store_based_barrier_key:2 to store for rank: 0\n",
"colossalai - torch.distributed.distributed_c10d - 2021-10-15 03:27:51,605 INFO: Rank 0: Completed store-based barrier for 1 nodes.\n",
"colossalai - torch.distributed.distributed_c10d - 2021-10-15 03:27:51,608 INFO: Added key: store_based_barrier_key:3 to store for rank: 0\n",
"colossalai - torch.distributed.distributed_c10d - 2021-10-15 03:27:51,610 INFO: Rank 0: Completed store-based barrier for 1 nodes.\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"process rank 0 is bound to device 0\n",
"initialized seed on rank 0, numpy: 1024, python random: 1024, ParallelMode.DATA: 1024, ParallelMode.TENSOR: 1124,the default parallel seed is ParallelMode.DATA.\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ppjmMxc_81TK"
},
"source": [
"Load and normalize the CIFAR10 training and test datasets using `colossalai.nn.data`. Note that we have wrapped `torchvision.transforms`, so that we can simply use the config dict to use them."
]
},
{
"cell_type": "code",
"metadata": {
"id": "ZyGhyD47-dUY",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "98bbf2d1-a1c4-4bb4-b6df-600777b1e8f5"
},
"source": [
"transform_cfg = [\n",
" dict(type='ToTensor'),\n",
" dict(type='Normalize',\n",
" mean=[0.4914, 0.4822, 0.4465],\n",
" std=[0.2023, 0.1994, 0.2010]),\n",
"]\n",
"\n",
"batch_size = 128\n",
"\n",
"trainset = colossalai.nn.data.CIFAR10Dataset(transform_cfg, root='./data', train=True)\n",
"trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)\n",
"\n",
"testset = colossalai.nn.data.CIFAR10Dataset(transform_cfg, root='./data', train=False)\n",
"testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Files already downloaded and verified\n",
"Files already downloaded and verified\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NvPbfLLR9NzC"
},
"source": [
"We just define a simple Convolutional Neural Network here."
]
},
{
"cell_type": "code",
"metadata": {
"id": "cQ_y7lBG09LS"
},
"source": [
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"\n",
"class Net(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.conv1 = nn.Conv2d(3, 6, 5)\n",
" self.pool = nn.MaxPool2d(2, 2)\n",
" self.conv2 = nn.Conv2d(6, 16, 5)\n",
" self.fc1 = nn.Linear(16 * 5 * 5, 120)\n",
" self.fc2 = nn.Linear(120, 84)\n",
" self.fc3 = nn.Linear(84, 10)\n",
"\n",
" def forward(self, x):\n",
" x = self.pool(F.relu(self.conv1(x)))\n",
" x = self.pool(F.relu(self.conv2(x)))\n",
" x = torch.flatten(x, 1) # flatten all dimensions except batch\n",
" x = F.relu(self.fc1(x))\n",
" x = F.relu(self.fc2(x))\n",
" x = self.fc3(x)\n",
" return x\n",
"\n",
"\n",
"model = Net().cuda()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "tgsszAmM9dYZ"
},
"source": [
"Define a Loss function and optimizer. And then we use them to initialize `Engine` and `Trainer`. We provide various training / evaluating hooks. In this case, we just use the simplest hooks which can compute and print loss and accuracy."
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "YtaDoCax1BCf",
"outputId": "b33b1641-03d8-4597-c8c2-1a4c1d61e9b0"
},
"source": [
"import torch.optim as optim\n",
"\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)\n",
"schedule = NoPipelineSchedule()\n",
"engine = Engine(\n",
" model=model,\n",
" criterion=criterion,\n",
" optimizer=optimizer,\n",
" lr_scheduler=None,\n",
" schedule=schedule\n",
" )\n",
"trainer = Trainer(engine=engine,\n",
" hooks_cfg=[dict(type='LossHook'), dict(type='LogMetricByEpochHook'), dict(type='AccuracyHook')],\n",
" verbose=True)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"colossalai - rank_0 - 2021-10-15 03:27:56,018 WARNING: No gradient handler is set up, please make sure you do not need to all-reduce the gradients after a training step.\n",
"colossalai - rank_0 - 2021-10-15 03:27:56,024 INFO: build LogMetricByEpochHook for train, priority = 1\n",
"colossalai - rank_0 - 2021-10-15 03:27:56,026 INFO: build LossHook for train, priority = 10\n",
"colossalai - rank_0 - 2021-10-15 03:27:56,029 INFO: build AccuracyHook for train, priority = 10\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_JR2TuvH99Ik"
},
"source": [
"Then we set training configs. We train our model for 10 epochs and it will be evaluated every 1 epoch. Set `display_progress` to `True` to display the training / evaluating progress bar."
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "w-J3IP-J1sfx",
"outputId": "bdb76939-04f1-4124-ce5e-3af44c0d902c"
},
"source": [
"num_epochs = 10\n",
"test_interval = 1\n",
"trainer.fit(\n",
" train_dataloader=trainloader,\n",
" test_dataloader=testloader,\n",
" max_epochs=num_epochs,\n",
" display_progress=True,\n",
" test_interval=test_interval\n",
" )"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"[Epoch 0 train]: 0%| | 0/391 [00:00<?, ?it/s]/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /pytorch/c10/core/TensorImpl.h:1156.)\n",
" return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)\n",
"[Epoch 0 train]: 100%|██████████| 391/391 [00:14<00:00, 26.82it/s]\n",
"colossalai - rank_0 - 2021-10-15 03:28:11,088 INFO: Training - Epoch 1 - LogMetricByEpochHook: Loss = 2.29158\n",
"[Epoch 0 val]: 100%|██████████| 79/79 [00:02<00:00, 28.66it/s]\n",
"colossalai - rank_0 - 2021-10-15 03:28:14,040 INFO: Testing - Epoch 1 - LogMetricByEpochHook: Loss = 2.26517, Accuracy = 0.14820\n",
"[Epoch 1 train]: 100%|██████████| 391/391 [00:14<00:00, 26.31it/s]\n",
"colossalai - rank_0 - 2021-10-15 03:28:29,059 INFO: Training - Epoch 2 - LogMetricByEpochHook: Loss = 2.15763\n",
"[Epoch 1 val]: 100%|██████████| 79/79 [00:02<00:00, 28.50it/s]\n",
"colossalai - rank_0 - 2021-10-15 03:28:32,007 INFO: Testing - Epoch 2 - LogMetricByEpochHook: Loss = 2.00450, Accuracy = 0.27850\n",
"[Epoch 2 train]: 100%|██████████| 391/391 [00:14<00:00, 26.08it/s]\n",
"colossalai - rank_0 - 2021-10-15 03:28:47,167 INFO: Training - Epoch 3 - LogMetricByEpochHook: Loss = 1.85409\n",
"[Epoch 2 val]: 100%|██████████| 79/79 [00:02<00:00, 27.89it/s]\n",
"colossalai - rank_0 - 2021-10-15 03:28:50,168 INFO: Testing - Epoch 3 - LogMetricByEpochHook: Loss = 1.73788, Accuracy = 0.35990\n",
"[Epoch 3 train]: 100%|██████████| 391/391 [00:14<00:00, 26.09it/s]\n",
"colossalai - rank_0 - 2021-10-15 03:29:05,330 INFO: Training - Epoch 4 - LogMetricByEpochHook: Loss = 1.69363\n",
"[Epoch 3 val]: 100%|██████████| 79/79 [00:02<00:00, 28.43it/s]\n",
"colossalai - rank_0 - 2021-10-15 03:29:08,290 INFO: Testing - Epoch 4 - LogMetricByEpochHook: Loss = 1.65005, Accuracy = 0.39350\n",
"[Epoch 4 train]: 100%|██████████| 391/391 [00:15<00:00, 25.97it/s]\n",
"colossalai - rank_0 - 2021-10-15 03:29:23,530 INFO: Training - Epoch 5 - LogMetricByEpochHook: Loss = 1.61387\n",
"[Epoch 4 val]: 100%|██████████| 79/79 [00:02<00:00, 27.75it/s]\n",
"colossalai - rank_0 - 2021-10-15 03:29:26,515 INFO: Testing - Epoch 5 - LogMetricByEpochHook: Loss = 1.57507, Accuracy = 0.42430\n",
"[Epoch 5 train]: 100%|██████████| 391/391 [00:15<00:00, 25.92it/s]\n",
"colossalai - rank_0 - 2021-10-15 03:29:41,764 INFO: Training - Epoch 6 - LogMetricByEpochHook: Loss = 1.55712\n",
"[Epoch 5 val]: 100%|██████████| 79/79 [00:02<00:00, 27.51it/s]\n",
"colossalai - rank_0 - 2021-10-15 03:29:44,778 INFO: Testing - Epoch 6 - LogMetricByEpochHook: Loss = 1.53242, Accuracy = 0.43700\n",
"[Epoch 6 train]: 100%|██████████| 391/391 [00:14<00:00, 26.13it/s]\n",
"colossalai - rank_0 - 2021-10-15 03:29:59,927 INFO: Training - Epoch 7 - LogMetricByEpochHook: Loss = 1.51618\n",
"[Epoch 6 val]: 100%|██████████| 79/79 [00:02<00:00, 28.31it/s]\n",
"colossalai - rank_0 - 2021-10-15 03:30:02,884 INFO: Testing - Epoch 7 - LogMetricByEpochHook: Loss = 1.49720, Accuracy = 0.45430\n",
"[Epoch 7 train]: 100%|██████████| 391/391 [00:14<00:00, 26.23it/s]\n",
"colossalai - rank_0 - 2021-10-15 03:30:17,968 INFO: Training - Epoch 8 - LogMetricByEpochHook: Loss = 1.47857\n",
"[Epoch 7 val]: 100%|██████████| 79/79 [00:02<00:00, 27.97it/s]\n",
"colossalai - rank_0 - 2021-10-15 03:30:20,967 INFO: Testing - Epoch 8 - LogMetricByEpochHook: Loss = 1.45808, Accuracy = 0.46320\n",
"[Epoch 8 train]: 100%|██████████| 391/391 [00:14<00:00, 26.11it/s]\n",
"colossalai - rank_0 - 2021-10-15 03:30:36,129 INFO: Training - Epoch 9 - LogMetricByEpochHook: Loss = 1.44656\n",
"[Epoch 8 val]: 100%|██████████| 79/79 [00:02<00:00, 28.18it/s]\n",
"colossalai - rank_0 - 2021-10-15 03:30:39,096 INFO: Testing - Epoch 9 - LogMetricByEpochHook: Loss = 1.44903, Accuracy = 0.46580\n",
"[Epoch 9 train]: 100%|██████████| 391/391 [00:15<00:00, 25.97it/s]\n",
"colossalai - rank_0 - 2021-10-15 03:30:54,342 INFO: Training - Epoch 10 - LogMetricByEpochHook: Loss = 1.41120\n",
"[Epoch 9 val]: 100%|██████████| 79/79 [00:02<00:00, 28.05it/s]\n",
"colossalai - rank_0 - 2021-10-15 03:30:57,332 INFO: Testing - Epoch 10 - LogMetricByEpochHook: Loss = 1.41242, Accuracy = 0.48500\n"
]
}
]
}
]
}

View File

@ -3,26 +3,18 @@
import colossalai
from colossalai.core import global_context as gpc
from colossalai.engine import Engine
from colossalai.logging import get_global_dist_logger
from colossalai.trainer import Trainer
def run_trainer():
model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize()
engine, train_dataloader, test_dataloader = colossalai.initialize()
logger = get_global_dist_logger()
schedule.data_sync = False
engine = Engine(
model=model,
criterion=criterion,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
schedule=schedule
)
engine.schedule.data_sync = False
logger.info("engine is built", ranks=[0])
trainer = Trainer(engine=engine,
hooks_cfg=gpc.config.hooks,
verbose=True)
logger.info("trainer is built", ranks=[0])
@ -30,7 +22,8 @@ def run_trainer():
trainer.fit(
train_dataloader=train_dataloader,
test_dataloader=test_dataloader,
max_epochs=gpc.config.num_epochs,
epochs=gpc.config.num_epochs,
hooks_cfg=gpc.config.hooks,
display_progress=True,
test_interval=2
)

View File

@ -3,5 +3,5 @@ torchvision>=0.9
numpy
tqdm
psutil
tensorboardX
tensorboard
packaging

View File

@ -121,7 +121,7 @@ if "--cuda_ext" in sys.argv:
install_requires = fetch_requirements('requirements/requirements.txt')
setup(
name='colossal-ai',
name='colossalai',
version='0.0.1-beta',
packages=find_packages(exclude=('csrc',
'tests',

View File

@ -27,8 +27,6 @@ train_data = dict(
dataloader=dict(
batch_size=BATCH_SIZE,
pin_memory=True,
# num_workers=1,
# shuffle=True,
)
)
@ -63,14 +61,6 @@ loss = dict(
type='CrossEntropyLoss2D',
)
# model = dict(
# type='VanillaResNet',
# block_type='ResNetBasicBlock',
# layers=[2, 2, 2, 2],
# num_cls=10
# )
model = dict(
type='VisionTransformerFromConfig',
tensor_splitting_cfg=dict(
@ -135,25 +125,26 @@ parallel = dict(
fp16 = dict(
mode=AMP_TYPE.PARALLEL,
initial_scale=2 ** 8
)
# fp16 = dict(
# mode=None,
# )
schedule = dict(
num_microbatches=2
)
lr_scheduler = dict(
type='LinearWarmupLR',
warmup_epochs=5
engine = dict(
schedule=dict(
num_microbatches=2
)
)
hooks = [
dict(
type='LRSchedulerHook',
by_epoch=True,
lr_scheduler_cfg=dict(
type='LinearWarmupLR',
warmup_steps=5
)
),
]
num_epochs = 60
logging = dict(
root_path='test_vit_2d_log'
)
seed = 100

View File

@ -124,14 +124,21 @@ parallel = dict(
tensor=dict(size=4, depth=1, mode='2.5d'),
)
lr_scheduler = dict(
type='LinearWarmupLR',
warmup_epochs=5
)
hooks = [
dict(
type='LRSchedulerHook',
by_epoch=True,
lr_scheduler_cfg=dict(
type='LinearWarmupLR',
warmup_steps=5
)
),
]
engine = dict(
schedule = dict(
num_microbatches=2
)
)
num_epochs = 60
num_microbatches = 1

View File

@ -9,21 +9,22 @@ import torch.autograd
import colossalai
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.engine import Engine
from colossalai.logging import get_global_dist_logger
from colossalai.nn.layer._parallel_utilities import _gather
CONFIG_PATH = Path(__file__).parent.parent.joinpath('configs/vit_2d.py')
def eval(engine):
def eval(engine, test_dataloader):
engine.eval()
accumulated_loss = 0
correct_sum = 0
total_sum = 0
num_steps = len(test_dataloader)
data_iter = iter(test_dataloader)
for i in range(engine.schedule.num_steps):
output, label, loss = engine.step()
for i in range(num_steps):
output, label, loss = engine.step(data_iter)
if gpc.is_last_rank(ParallelMode.PIPELINE):
# loss = sum(loss)
@ -43,20 +44,22 @@ def eval(engine):
correct = torch.sum(label == output)
correct_sum += correct
total_sum += label.size(0)
avg_loss = accumulated_loss / engine.schedule.num_steps
avg_loss = accumulated_loss / num_steps
return correct_sum, total_sum, avg_loss
def train(engine):
def train(engine, train_dataloader):
engine.train()
accumulated_loss = 0
num_steps = len(train_dataloader)
data_iter = iter(train_dataloader)
for i in range(engine.schedule.num_steps):
output, label, loss = engine.step()
for i in range(num_steps):
output, label, loss = engine.step(data_iter)
if gpc.is_last_rank(ParallelMode.PIPELINE):
accumulated_loss += loss.detach().cpu().numpy()
avg_loss = accumulated_loss / engine.schedule.num_steps
avg_loss = accumulated_loss / num_steps
return avg_loss
@ -64,25 +67,16 @@ def train(engine):
@pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus")
def test_2d_parallel_vision_transformer():
# init dist
model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize(
CONFIG_PATH)
engine, train_dataloader, test_dataloader = colossalai.initialize(CONFIG_PATH)
logger = get_global_dist_logger()
engine = Engine(model=model,
train_dataloader=train_dataloader,
test_dataloader=test_dataloader,
criterion=criterion,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
schedule=schedule)
for epoch in range(gpc.config.num_epochs):
train_loss = train(engine)
train_loss = train(engine, train_dataloader)
if gpc.is_last_rank(ParallelMode.PIPELINE):
logger.info(f'epoch {epoch} - train loss: {train_loss}')
if epoch % 2 == 0:
correct_sum, total_sum, eval_loss = eval(engine)
correct_sum, total_sum, eval_loss = eval(engine, test_dataloader)
if gpc.is_last_rank(ParallelMode.PIPELINE):
logger.info(
f'epoch {epoch} - eval loss: {eval_loss}, total: {total_sum}, '

View File

@ -6,20 +6,22 @@ import torch.autograd
import colossalai
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.engine import Engine
from colossalai.logging import get_global_dist_logger
from colossalai.nn.layer._parallel_utilities import _gather
CONFIG_PATH = Path(__file__).parent.parent.joinpath('configs/vit_2p5d.py')
def eval(engine):
def eval(engine, test_dataloader):
engine.eval()
accumulated_loss = 0
correct_sum = 0
total_sum = 0
num_steps = len(test_dataloader)
data_iter = iter(test_dataloader)
for i in range(engine.schedule.num_steps):
output, label, loss = engine.step()
for i in range(num_steps):
output, label, loss = engine.step(data_iter)
if gpc.is_last_rank(ParallelMode.PIPELINE):
accumulated_loss += loss.detach().cpu().numpy()
@ -43,21 +45,23 @@ def eval(engine):
correct = torch.sum(label == output)
correct_sum += correct
total_sum += label.size(0)
avg_loss = accumulated_loss / engine.schedule.num_steps
avg_loss = accumulated_loss / num_steps
return correct_sum, total_sum, avg_loss
def train(engine):
def train(engine, train_dataloader):
engine.train()
accumulated_loss = 0
num_steps = len(train_dataloader)
data_iter = iter(train_dataloader)
for i in range(num_steps):
output, label, loss = engine.step(data_iter)
for i in range(engine.schedule.num_steps):
output, label, loss = engine.step()
if gpc.is_last_rank(ParallelMode.PIPELINE):
accumulated_loss += loss.detach().cpu().numpy()
avg_loss = accumulated_loss / engine.schedule.num_steps
avg_loss = accumulated_loss / num_steps
return avg_loss
@ -65,25 +69,16 @@ def train(engine):
@pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus")
def test_2p5d_parallel_vision_transformer():
# init dist
model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize(
CONFIG_PATH)
engine, train_dataloader, test_dataloader = colossalai.initialize(CONFIG_PATH)
logger = get_global_dist_logger()
engine = Engine(model=model,
train_dataloader=train_dataloader,
test_dataloader=test_dataloader,
criterion=criterion,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
schedule=schedule)
for epoch in range(gpc.config.num_epochs):
train_loss = train(engine)
train_loss = train(engine, train_dataloader)
if gpc.is_last_rank(ParallelMode.PIPELINE):
logger.info(f'epoch {epoch} - train loss: {train_loss}')
if epoch % 2 == 0:
correct_sum, total_sum, eval_loss = eval(engine)
correct_sum, total_sum, eval_loss = eval(engine, test_dataloader)
if gpc.is_last_rank(ParallelMode.PIPELINE):
logger.info(
f'epoch {epoch} - eval loss: {eval_loss}, total: {total_sum}, '
@ -91,4 +86,4 @@ def test_2p5d_parallel_vision_transformer():
if __name__ == '__main__':
test_2p5d_parallel_vision_transformer()
test_2p5d_parallel_vision_transformer()

View File

@ -38,5 +38,3 @@ optimizer = dict(type='Adam', lr=0.001)
loss = dict(type='CrossEntropyLoss')
# set_device_func = lambda global_rank, world_size: global_rank % 4
seed = 1024

View File

@ -40,6 +40,3 @@ optimizer = dict(type='Adam', lr=0.001)
loss = dict(type='CrossEntropyLoss')
fp16 = dict(mode=AMP_TYPE.APEX)
# set_device_func = lambda global_rank, world_size: global_rank % 4
seed = 1024

View File

@ -40,6 +40,3 @@ optimizer = dict(type='Adam', lr=0.001)
loss = dict(type='CrossEntropyLoss')
fp16 = dict(mode=AMP_TYPE.TORCH)
# set_device_func = lambda global_rank, world_size: global_rank % 4
seed = 1024

View File

@ -38,11 +38,9 @@ parallel = dict(
tensor=dict(size=1, mode=None)
)
schedule = dict(
num_microbatches=4
engine = dict(
schedule=dict(
num_microbatches=4
)
)
num_pipeling_batches = 2
seed = 1024
lr_scheduler = dict(type='LinearWarmupLR', warmup_steps=5)
num_epochs = 10

View File

@ -8,7 +8,6 @@ import torch
from colossalai import initialize
from colossalai.core import global_context as gpc
from colossalai.engine import Engine
from colossalai.logging import get_global_dist_logger
from colossalai.utils import report_memory_usage
@ -24,20 +23,13 @@ NO_PIPE_CONFIG_PATH = osp.join(DIR_PATH, '../configs/non_pipeline_resnet_apex_am
def run_no_pipeline(config):
model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = initialize(config)
engine, train_dataloader, test_dataloader = initialize(config)
logger = get_global_dist_logger()
rank = torch.distributed.get_rank()
engine = Engine(model=model,
train_dataloader=train_dataloader,
criterion=criterion,
optimizer=optimizer,
schedule=schedule)
engine.train()
logger.info('lr = %g' % engine.get_lr())
output, label, loss = engine.step()
output, label, loss = engine.step(iter(train_dataloader))
logger.info('Rank {} returns: {}'.format(rank, loss.item()))
logger.info('lr = %g' % engine.get_lr())
gpc.destroy()
logger.info('Test engine finished')

View File

@ -8,7 +8,6 @@ import torch
from colossalai import initialize
from colossalai.core import global_context as gpc
from colossalai.engine import Engine
from colossalai.logging import get_global_dist_logger
from colossalai.utils import report_memory_usage
@ -26,21 +25,14 @@ NO_PIPE_CONFIG_PATH = osp.join(DIR_PATH, '../configs/non_pipeline_resnet.py')
def test_no_pipeline(config):
print('Test no pipeline engine start')
model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = initialize(config)
engine, train_dataloader, test_dataloader = initialize(config)
logger = get_global_dist_logger()
rank = torch.distributed.get_rank()
engine = Engine(model=model,
train_dataloader=train_dataloader,
criterion=criterion,
optimizer=optimizer,
schedule=schedule)
engine.train()
logger.info('lr = %g' % engine.get_lr())
output, label, loss = engine.step()
output, label, loss = engine.step(iter(train_dataloader))
logger.info('Rank {} returns: {}'.format(rank, loss.item()))
logger.info('lr = %g' % engine.get_lr())
gpc.destroy()
logger.info('Test engine finished')

View File

@ -8,7 +8,6 @@ import torch
from colossalai import initialize
from colossalai.core import global_context as gpc
from colossalai.engine import Engine
from colossalai.logging import get_global_dist_logger
from colossalai.utils import report_memory_usage
@ -26,21 +25,13 @@ NO_PIPE_CONFIG_PATH = osp.join(DIR_PATH, '../configs/non_pipeline_resnet_torch_a
def test_no_pipeline(config):
print('Test no pipeline engine start')
model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = initialize(config)
engine, train_dataloader, test_dataloader = initialize(config)
logger = get_global_dist_logger()
rank = torch.distributed.get_rank()
engine = Engine(model=model,
train_dataloader=train_dataloader,
criterion=criterion,
optimizer=optimizer,
schedule=schedule)
engine.train()
logger.info('lr = %g' % engine.get_lr())
output, label, loss = engine.step()
output, label, loss = engine.step(iter(train_dataloader))
logger.info('Rank {} returns: {}'.format(rank, loss.item()))
logger.info('lr = %g' % engine.get_lr())
gpc.destroy()
logger.info('Test engine finished')

View File

@ -5,6 +5,7 @@ import os.path as osp
import pytest
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import initialize
from colossalai.logging import get_global_dist_logger
@ -22,13 +23,25 @@ CONFIG_PATH = osp.join(DIR_PATH, '../configs/pipeline_vanilla_resnet.py')
@pytest.mark.skip("This test should be invoked using the test.sh provided")
@pytest.mark.dist
def test_schedule():
model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = initialize(CONFIG_PATH)
engine, train_dataloader, test_dataloader = initialize(CONFIG_PATH)
logger = get_global_dist_logger()
schedule.zero_grad()
output, label, losses = schedule.forward_backward_step(forward_only=False)
schedule.step()
logger.info('losses: {}'.format([loss.item() for loss in losses]))
model = engine.model
optimizer = engine.optimizer
criterion = engine.criterion
schedule = engine._schedule
output, label, loss = schedule.forward_backward_step(
data_iter=iter(train_dataloader),
model=model,
optimizer=optimizer,
criterion=criterion,
forward_only=False
)
schedule.optimizer_step(model, optimizer)
if gpc.is_last_rank(ParallelMode.PIPELINE):
logger.info('losses: {}'.format(loss))
gpc.destroy()
logger.info('training finished')

View File

@ -9,7 +9,6 @@ import torch
from colossalai import initialize
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.engine import Engine
from colossalai.logging import get_global_dist_logger
NUM_BATCH = 128
@ -23,22 +22,14 @@ PIPE_CONFIG_PATH = osp.join(DIR_PATH, '../configs/pipeline_vanilla_resnet.py')
def run_pipeline(config):
model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = initialize(config)
engine, train_dataloader, test_dataloader = initialize(config)
logger = get_global_dist_logger()
rank = torch.distributed.get_rank()
engine = Engine(model=model,
train_dataloader=train_dataloader,
criterion=criterion,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
schedule=schedule)
engine.train()
logger.info('lr = %g' % engine.get_lr())
outputs, labels, loss = engine.step()
outputs, labels, loss = engine.step(iter(train_dataloader))
if gpc.is_last_rank(ParallelMode.PIPELINE):
logger.info('losses: {}'.format(rank, loss.item()))
logger.info('lr = %g' % engine.get_lr())
gpc.destroy()
logger.info('Test engine pipeline finished')

View File

@ -132,9 +132,12 @@ fp16 = dict(
initial_scale=2 ** 4
)
num_epochs = 60
lr_scheduler = dict(
type='LinearWarmupLR',
warmup_epochs=5
warmup_steps=5,
total_steps=num_epochs
)
num_epochs = 60

View File

@ -7,23 +7,25 @@ import pytest
import torch.autograd
import colossalai
from colossalai.builder import build_lr_scheduler
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.engine import Engine
from colossalai.logging import get_global_dist_logger
from colossalai.nn.layer._parallel_utilities import _gather
CONFIG_PATH = Path(__file__).parent.parent.joinpath('configs/vit_2d.py')
def eval(engine):
def eval(engine, test_dataloader):
engine.eval()
accumulated_loss = 0
correct_sum = 0
total_sum = 0
num_steps = len(test_dataloader)
data_iter = iter(test_dataloader)
for i in range(engine.schedule.num_steps):
output, label, loss = engine.step()
for i in range(num_steps):
output, label, loss = engine.step(data_iter)
accumulated_loss += loss.detach().cpu().numpy()
output = _gather(
@ -40,18 +42,21 @@ def eval(engine):
correct = torch.sum(label[0] == output)
correct_sum += correct
total_sum += label[0].size(0)
avg_loss = accumulated_loss / engine.schedule.num_steps
avg_loss = accumulated_loss / num_steps
return correct_sum, total_sum, avg_loss
def train(engine):
def train(engine, train_dataloader, lr_scheduler):
engine.train()
accumulated_loss = 0
num_steps = len(train_dataloader)
data_iter = iter(train_dataloader)
for i in range(engine.schedule.num_steps):
output, label, loss = engine.step()
for i in range(num_steps):
output, label, loss = engine.step(data_iter)
accumulated_loss += loss.squeeze(0).detach().cpu().numpy()
avg_loss = accumulated_loss / engine.schedule.num_steps
avg_loss = accumulated_loss / num_steps
lr_scheduler.step()
return avg_loss
@ -59,26 +64,18 @@ def train(engine):
@pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus")
def test_2d_parallel_vision_transformer():
# init dist
model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize(
CONFIG_PATH)
engine, train_dataloader, test_dataloader = colossalai.initialize(CONFIG_PATH)
lr_scheduler = build_lr_scheduler(gpc.config.lr_scheduler, engine.optimizer)
logger = get_global_dist_logger()
engine = Engine(model=model,
train_dataloader=train_dataloader,
test_dataloader=test_dataloader,
criterion=criterion,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
schedule=schedule)
logger.info('start training')
for epoch in range(gpc.config.num_epochs):
train_loss = train(engine)
train_loss = train(engine, train_dataloader, lr_scheduler)
logger.info(f'epoch {epoch} - train loss: {train_loss}')
if epoch % 2 == 0:
correct_sum, total_sum, eval_loss = eval(engine)
correct_sum, total_sum, eval_loss = eval(engine, test_dataloader)
logger.info(
f'epoch {epoch} - eval loss: {eval_loss}, total: {total_sum}, '
f'correct: {correct_sum}, acc: {correct_sum / total_sum}')

View File

@ -102,6 +102,6 @@ parallel = dict(
tensor=dict(size=4, mode='2d'),
)
lr_scheduler = dict(type='LinearWarmupLR', warmup_epochs=5)
num_epochs = 60
lr_scheduler = dict(type='LinearWarmupLR', warmup_steps=5, total_steps=num_epochs)

View File

@ -125,13 +125,6 @@ parallel = dict(
tensor=dict(size=4, depth=1, mode='2.5d'),
)
lr_scheduler = dict(
type='LinearWarmupLR',
warmup_epochs=5
)
schedule = dict(
num_microbatches=8
)
num_epochs = 60
lr_scheduler = dict(type='LinearWarmupLR', warmup_steps=5, total_steps=num_epochs)

View File

@ -116,9 +116,14 @@ hooks = [
weight_parallel_mode=ParallelMode.PARALLEL_3D_WEIGHT,
),
dict(type='LossHook'),
# dict(type='TensorboardHook', log_dir='./tfb_logs'),
# dict(type='SaveCheckpointHook', interval=5, checkpoint_dir='./ckpt'),
# dict(type='LoadCheckpointHook', epoch=20, checkpoint_dir='./ckpt')
dict(
type='LRSchedulerHook',
by_epoch=True,
lr_scheduler_cfg=dict(
type='LinearWarmupLR',
warmup_steps=5
)
),
]
parallel = dict(
@ -127,12 +132,4 @@ parallel = dict(
tensor=dict(mode='3d', size=8),
)
# fp16 = dict(mode=AMP_TYPE.PARALLEL, initial_scale=2 ** 6)
lr_scheduler = dict(type='LinearWarmupLR', warmup_epochs=5)
# schedule = dict(num_microbatches=4)
num_epochs = 60
seed = 42

View File

@ -7,23 +7,25 @@ import pytest
import torch.autograd
import colossalai
from colossalai.builder import build_lr_scheduler
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.engine import Engine
from colossalai.logging import get_global_dist_logger
from colossalai.nn.layer._parallel_utilities import _gather
CONFIG_PATH = Path(__file__).parent.parent.joinpath('configs/vit_2d.py')
def eval(engine):
def eval(engine, test_dataloader):
engine.eval()
accumulated_loss = 0
correct_sum = 0
total_sum = 0
num_steps = len(test_dataloader)
data_iter = iter(test_dataloader)
for i in range(engine.schedule.num_steps):
output, label, loss = engine.step()
for i in range(num_steps):
output, label, loss = engine.step(data_iter)
accumulated_loss += loss.detach().cpu().numpy()
output = _gather(
@ -40,18 +42,21 @@ def eval(engine):
correct = torch.sum(label[0] == output)
correct_sum += correct
total_sum += label[0].size(0)
avg_loss = accumulated_loss / engine.schedule.num_steps
avg_loss = accumulated_loss / num_steps
return correct_sum, total_sum, avg_loss
def train(engine):
def train(engine, train_dataloader, lr_scheduler):
engine.train()
accumulated_loss = 0
num_steps = len(train_dataloader)
data_iter = iter(train_dataloader)
for i in range(engine.schedule.num_steps):
output, label, loss = engine.step()
for i in range(num_steps):
output, label, loss = engine.step(data_iter)
accumulated_loss += loss.detach().cpu().numpy()
avg_loss = accumulated_loss / engine.schedule.num_steps
avg_loss = accumulated_loss / num_steps
lr_scheduler.step()
return avg_loss
@ -59,25 +64,17 @@ def train(engine):
@pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus")
def test_2d_parallel_vision_transformer():
# init dist
model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize(
CONFIG_PATH)
engine, train_dataloader, test_dataloader = colossalai.initialize(CONFIG_PATH)
lr_scheduler = build_lr_scheduler(gpc.config.lr_scheduler, engine.optimizer)
logger = get_global_dist_logger()
engine = Engine(model=model,
train_dataloader=train_dataloader,
test_dataloader=test_dataloader,
criterion=criterion,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
schedule=schedule)
logger.info('start training')
for epoch in range(gpc.config.num_epochs):
train_loss = train(engine)
train_loss = train(engine, train_dataloader, lr_scheduler)
logger.info(f'epoch {epoch} - train loss: {train_loss}')
if epoch % 2 == 0:
correct_sum, total_sum, eval_loss = eval(engine)
correct_sum, total_sum, eval_loss = eval(engine, test_dataloader)
logger.info(
f'epoch {epoch} - eval loss: {eval_loss}, total: {total_sum}, '
f'correct: {correct_sum}, acc: {correct_sum / total_sum}')

View File

@ -4,22 +4,25 @@ import pytest
import torch.autograd
import colossalai
from colossalai.builder import build_lr_scheduler
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.engine import Engine
from colossalai.logging import get_global_dist_logger
from colossalai.nn.layer._parallel_utilities import _gather
CONFIG_PATH = Path(__file__).parent.parent.joinpath('configs/vit_2p5d.py')
def eval(engine):
def eval(engine, test_dataloader):
engine.eval()
accumulated_loss = 0
correct_sum = 0
total_sum = 0
num_steps = len(test_dataloader)
data_iter = iter(test_dataloader)
for i in range(engine.schedule.num_steps):
output, label, loss = engine.step()
for i in range(num_steps):
output, label, loss = engine.step(data_iter)
accumulated_loss += loss.detach().cpu().numpy()
output = _gather(
@ -41,18 +44,21 @@ def eval(engine):
correct = torch.sum(label[0] == output)
correct_sum += correct
total_sum += label[0].size(0)
avg_loss = accumulated_loss / engine.schedule.num_steps
avg_loss = accumulated_loss / num_steps
return correct_sum, total_sum, avg_loss
def train(engine):
def train(engine, train_dataloader, lr_scheduler):
engine.train()
accumulated_loss = 0
num_steps = len(train_dataloader)
data_iter = iter(train_dataloader)
for i in range(engine.schedule.num_steps):
output, label, loss = engine.step()
for i in range(num_steps):
output, label, loss = engine.step(data_iter)
accumulated_loss += loss.detach().cpu().numpy()
avg_loss = accumulated_loss / engine.schedule.num_steps
avg_loss = accumulated_loss / num_steps
lr_scheduler.step()
return avg_loss
@ -60,29 +66,21 @@ def train(engine):
@pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus")
def test_2p5d_parallel_vision_transformer():
# init dist
model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize(
CONFIG_PATH)
engine, train_dataloader, test_dataloader = colossalai.initialize(CONFIG_PATH)
lr_scheduler = build_lr_scheduler(gpc.config.lr_scheduler, engine.optimizer)
logger = get_global_dist_logger()
engine = Engine(model=model,
train_dataloader=train_dataloader,
test_dataloader=test_dataloader,
criterion=criterion,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
schedule=schedule)
logger.info('start training')
for epoch in range(gpc.config.num_epochs):
train_loss = train(engine)
train_loss = train(engine, train_dataloader, lr_scheduler)
logger.info(f'epoch {epoch} - train loss: {train_loss}')
if epoch % 2 == 0:
correct_sum, total_sum, eval_loss = eval(engine)
correct_sum, total_sum, eval_loss = eval(engine, test_dataloader)
logger.info(
f'epoch {epoch} - eval loss: {eval_loss}, total: {total_sum}, '
f'correct: {correct_sum}, acc: {correct_sum / total_sum}')
if __name__ == '__main__':
test_2p5d_parallel_vision_transformer()
test_2p5d_parallel_vision_transformer()

View File

@ -1,16 +1,14 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import time
from pathlib import Path
import torch
from tqdm import tqdm
from colossalai import initialize
import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.engine import Engine
from colossalai.logging import get_global_dist_logger
from colossalai.trainer import Trainer
from colossalai.trainer.metric import Accuracy3D
@ -29,7 +27,7 @@ def _train_epoch(epoch, engine):
num_samples = 0
now = time.time()
epoch_start = now
progress = range(engine.schedule.num_steps)
progress = range(engine._schedule.num_steps)
if gpc.get_global_rank() == 0:
progress = tqdm(progress, desc='[Epoch %d]' % epoch, miniters=1)
for step in progress:
@ -68,7 +66,7 @@ def _eval(epoch, engine):
ParallelMode.PARALLEL_3D_WEIGHT)
total = 0
with torch.no_grad():
for _ in range(engine.schedule.num_steps):
for _ in range(engine._schedule.num_steps):
outputs, targets, loss = engine.step()
if isinstance(outputs, (list, tuple)):
outputs = outputs[0]
@ -80,32 +78,25 @@ def _eval(epoch, engine):
print_rank_0(
'[Epoch %d] Evaluation loss: %.3f | Acc: %.3f%%' %
(epoch, eval_loss / engine.schedule.num_steps,
(epoch, eval_loss / engine._schedule.num_steps,
acc.get_accumulated_value() * 100), logger)
def train():
model, train_dataloader, test_dataloader, criterion, \
optimizer, schedule, lr_scheduler = initialize(CONFIG_PATH)
# init dist
engine, train_dataloader, test_dataloader = colossalai.initialize(CONFIG_PATH)
logger = get_global_dist_logger()
engine = Engine(model=model,
train_dataloader=train_dataloader,
test_dataloader=test_dataloader,
criterion=criterion,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
schedule=schedule)
logger.info("Engine is built", ranks=[0])
trainer = Trainer(engine=engine, hooks_cfg=gpc.config.hooks, verbose=True)
trainer = Trainer(engine=engine, verbose=True)
logger.info("Trainer is built", ranks=[0])
logger.info("Train start", ranks=[0])
trainer.fit(train_dataloader=train_dataloader,
test_dataloader=test_dataloader,
max_epochs=gpc.config.num_epochs,
epochs=gpc.config.num_epochs,
hooks_cfg=gpc.config.hooks,
display_progress=True,
test_interval=1)

View File

@ -3,6 +3,7 @@ from pathlib import Path
BATCH_SIZE = 128
IMG_SIZE = 32
num_epochs = 200
# resnet 50
model = dict(
@ -77,18 +78,14 @@ hooks = [
dict(type='AccuracyHook'),
dict(type='LossHook'),
dict(type='TensorboardHook', log_dir='./tfb_logs'),
dict(
type='LRSchedulerHook',
by_epoch=True,
lr_scheduler_cfg=dict(
type='CosineAnnealingLR',
warmup_steps=5
)
),
dict(type='SaveCheckpointHook', interval=5, checkpoint_dir='./ckpt'),
# dict(type='LoadCheckpointHook', epoch=20, checkpoint_dir='./ckpt')
]
# fp16 = dict(
# mode=AMP_TYPE.PARALLEL,
# initial_scale=1
# )
lr_scheduler = dict(
type='CosineAnnealingLR',
T_max=200
)
num_epochs = 200

View File

@ -11,6 +11,7 @@ NUM_ATTENTION_HEADS = 8
SUMMA_DIM = 2
NUM_CLASSES = 10
DEPTH = 6
num_epochs = 60
train_data = dict(
dataset=dict(type='CIFAR10Dataset',
@ -52,13 +53,6 @@ optimizer = dict(type='Adam', lr=0.001, weight_decay=0)
loss = dict(type='CrossEntropyLoss2D', )
# model = dict(
# type='VanillaResNet',
# block_type='ResNetBasicBlock',
# layers=[2, 2, 2, 2],
# num_cls=10
# )
model = dict(
type='VisionTransformerFromConfig',
tensor_splitting_cfg=dict(type='ViTInputSplitter2D', ),
@ -114,8 +108,15 @@ hooks = [
dict(type='Accuracy2DHook'),
dict(type='LossHook'),
dict(type='TensorboardHook', log_dir='./tfb_logs'),
dict(
type='LRSchedulerHook',
by_epoch=True,
lr_scheduler_cfg=dict(
type='LinearWarmupLR',
warmup_steps=5
)
),
dict(type='SaveCheckpointHook', interval=5, checkpoint_dir='./ckpt'),
# dict(type='LoadCheckpointHook', epoch=20, checkpoint_dir='./ckpt')
]
parallel = dict(
@ -125,11 +126,8 @@ parallel = dict(
fp16 = dict(mode=AMP_TYPE.PARALLEL, initial_scale=2 ** 8)
lr_scheduler = dict(type='LinearWarmupLR', warmup_epochs=5)
schedule = dict(num_microbatches=1)
num_epochs = 60
num_microbatches = 1
engine = dict(
schedule=dict(num_microbatches=1)
)
logging = dict(root_path='./logs')

View File

@ -1,25 +1,16 @@
import colossalai
from colossalai.core import global_context as gpc
from colossalai.engine import Engine
from colossalai.logging import get_global_dist_logger
from colossalai.trainer import Trainer
def test_trainer():
model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize()
engine, train_dataloader, test_dataloader = colossalai.initialize()
logger = get_global_dist_logger()
engine = Engine(
model=model,
criterion=criterion,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
schedule=schedule
)
logger.info("engine is built", ranks=[0])
trainer = Trainer(engine=engine,
hooks_cfg=gpc.config.hooks,
verbose=True)
logger.info("trainer is built", ranks=[0])
@ -27,7 +18,8 @@ def test_trainer():
trainer.fit(
train_dataloader=train_dataloader,
test_dataloader=test_dataloader,
max_epochs=gpc.config.num_epochs,
hooks_cfg=gpc.config.hooks,
epochs=gpc.config.num_epochs,
display_progress=False,
test_interval=5
)

View File

@ -18,14 +18,16 @@ level = os.environ['LEVEL']
CONFIG_PATH = Path(__file__).parent.parent.joinpath(f'configs/vit_2d_zero{level}.py')
def eval(engine):
def eval_epoch(engine: Engine, test_dataloader):
engine.eval()
accumulated_loss = 0
correct_sum = 0
total_sum = 0
num_steps = len(test_dataloader)
data_iter = iter(test_dataloader)
for i in range(engine.schedule.num_steps):
output, label, loss = engine.step()
for i in range(num_steps):
output, label, loss = engine.step(data_iter)
accumulated_loss += loss.detach().cpu().numpy()
output = _gather(
@ -42,18 +44,19 @@ def eval(engine):
correct = torch.sum(label[0] == output)
correct_sum += correct
total_sum += label[0].size(0)
avg_loss = accumulated_loss / engine.schedule.num_steps
avg_loss = accumulated_loss / num_steps
return correct_sum, total_sum, avg_loss
def train(engine):
def train_epoch(engine, train_dataloader):
engine.train()
accumulated_loss = 0
for i in range(engine.schedule.num_steps):
output, label, loss = engine.step()
num_steps = len(train_dataloader)
data_iter = iter(train_dataloader)
for i in range(num_steps):
output, label, loss = engine.step(data_iter)
accumulated_loss += loss.detach().cpu().numpy()
avg_loss = accumulated_loss / engine.schedule.num_steps
avg_loss = accumulated_loss / num_steps
return avg_loss
@ -61,30 +64,17 @@ def train(engine):
@pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus")
def test_2d_parallel_vision_transformer():
# init dist
model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize(
CONFIG_PATH)
engine, train_dataloader, test_dataloader = colossalai.initialize(CONFIG_PATH)
logger = get_global_dist_logger()
engine = Engine(model=model,
train_dataloader=train_dataloader,
test_dataloader=test_dataloader,
criterion=criterion,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
schedule=schedule)
# for param in model.parameters():
# if isinstance(param, torch.HalfTensor):
# print(param.shape)
logger.info('start training')
for epoch in range(gpc.config.num_epochs):
train_loss = train(engine)
train_loss = train_epoch(engine, train_dataloader)
logger.info(f'epoch {epoch} - train loss: {train_loss}')
if epoch % 2 == 0:
correct_sum, total_sum, eval_loss = eval(engine)
correct_sum, total_sum, eval_loss = eval_epoch(engine, test_dataloader)
logger.info(
f'epoch {epoch} - eval loss: {eval_loss}, total: {total_sum}, '
f'correct: {correct_sum}, acc: {correct_sum / total_sum}')