From 3defa32aee5c8cac42b0625df258254d11cfaad7 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Thu, 18 Nov 2021 19:45:06 +0800 Subject: [PATCH] Support TP-compatible Torch AMP and Update trainer API (#27) * Add gradient accumulation, fix lr scheduler * fix FP16 optimizer and adapted torch amp with tensor parallel (#18) * fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes * fixed trainer * Revert "fixed trainer" This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097. * improved consistency between trainer, engine and schedule (#23) Co-authored-by: 1SAA Co-authored-by: 1SAA Co-authored-by: ver217 --- README.md | 14 +- colossalai/builder/__init__.py | 10 +- colossalai/builder/builder.py | 33 +- colossalai/engine/__init__.py | 2 +- colossalai/engine/_base_engine.py | 210 ++--- colossalai/engine/amp/__init__.py | 2 + colossalai/engine/{ => amp}/amp_type.py | 0 colossalai/engine/amp/grad_scaler.py | 577 ++++++++++++++ colossalai/engine/schedule/_base_schedule.py | 150 ++-- colossalai/engine/schedule/_no_pipeline.py | 131 ++-- colossalai/engine/schedule/_pipeline.py | 91 +-- colossalai/engine/schedule/_utils.py | 11 + colossalai/initialize.py | 56 +- colossalai/nn/layer/parallel_2d/_operation.py | 156 ++-- colossalai/nn/lr_scheduler/__init__.py | 2 +- colossalai/nn/lr_scheduler/cosine.py | 7 +- colossalai/nn/lr_scheduler/delayed.py | 15 +- colossalai/nn/lr_scheduler/linear.py | 15 - colossalai/nn/lr_scheduler/multistep.py | 16 +- colossalai/nn/lr_scheduler/torch.py | 42 +- colossalai/nn/optimizer/_utils.py | 2 +- .../zero_redundancy_optimizer_level_2.py | 15 +- .../zero_redundancy_optimizer_level_3.py | 2 +- colossalai/registry/__init__.py | 1 + colossalai/trainer/__init__.py | 4 +- colossalai/trainer/_trainer.py | 350 +++++---- colossalai/trainer/hooks/__init__.py | 2 + colossalai/trainer/hooks/_checkpoint_hook.py | 70 +- colossalai/trainer/hooks/_log_hook.py | 84 +- .../trainer/hooks/_lr_scheduler_hook.py | 58 ++ colossalai/trainer/hooks/_metric_hook.py | 64 +- colossalai/trainer/metric.py | 27 + colossalai/{ => utils}/checkpointing.py | 6 +- colossalai/utils/common.py | 2 +- configs/resnet/resnet50.py | 3 +- configs/sample_config.py | 7 +- configs/vit/vit_2d.py | 23 +- configs/vit/vit_3d.py | 16 +- .../colossalai.engine.amp.amp_type.rst | 5 + .../colossalai.engine.amp.grad_scaler.rst | 5 + docs/colossalai/colossalai.engine.amp.rst | 12 + .../colossalai/colossalai.engine.amp_type.rst | 5 - docs/colossalai/colossalai.engine.rst | 7 +- docs/colossalai/colossalai.rst | 1 - .../colossalai.utils.checkpointing.rst | 5 + docs/colossalai/colossalai.utils.rst | 1 + docs/parallelization.md | 55 +- docs/run_demo.md | 30 +- docs/run_demo_zh.md | 28 +- docs/trainer_engine.md | 69 +- docs/trainer_engine_zh.md | 19 +- examples/colossal_cifar_demo.ipynb | 728 +++++++++--------- examples/run_trainer.py | 17 +- requirements/requirements.txt | 2 +- setup.py | 2 +- .../configs/vit_2d.py | 37 +- .../configs/vit_2p5d.py | 17 +- .../test_vit_2d/test_vit_2d.py | 36 +- .../test_vit_2p5d/test_vit_2p5d.py | 41 +- .../configs/non_pipeline_resnet.py | 2 - .../configs/non_pipeline_resnet_apex_amp.py | 3 - .../configs/non_pipeline_resnet_torch_amp.py | 3 - .../configs/pipeline_vanilla_resnet.py | 10 +- .../test_engine_apex_amp.py | 12 +- .../test_engine_no_amp.py | 12 +- .../test_engine_torch_amp.py | 13 +- .../test_pipeline/test_schedule.py | 23 +- .../test_pipeline_engine/test_engine.py | 13 +- tests/test_fp16_optimizer/configs/vit_2d.py | 7 +- .../test_vit_2d/test_vit_2d.py | 39 +- .../test_vision_transformer/configs/vit_2d.py | 4 +- .../configs/vit_2p5d.py | 11 +- .../test_vision_transformer/configs/vit_3d.py | 19 +- .../test_vit_2d/test_vit_2d.py | 39 +- .../test_vit_2p5d/test_vit_2p5d.py | 42 +- .../test_vit_3d/test_vit_3d.py | 27 +- .../configs/test_trainer_resnet.py | 21 +- .../configs/test_trainer_vit_2d.py | 26 +- tests/test_trainer/test_trainer.py | 14 +- .../test_vit_2d/test_vit_2d.py | 40 +- 80 files changed, 2194 insertions(+), 1584 deletions(-) create mode 100644 colossalai/engine/amp/__init__.py rename colossalai/engine/{ => amp}/amp_type.py (100%) create mode 100644 colossalai/engine/amp/grad_scaler.py create mode 100644 colossalai/trainer/hooks/_lr_scheduler_hook.py rename colossalai/{ => utils}/checkpointing.py (98%) create mode 100644 docs/colossalai/colossalai.engine.amp.amp_type.rst create mode 100644 docs/colossalai/colossalai.engine.amp.grad_scaler.rst create mode 100644 docs/colossalai/colossalai.engine.amp.rst delete mode 100644 docs/colossalai/colossalai.engine.amp_type.rst create mode 100644 docs/colossalai/colossalai.utils.checkpointing.rst diff --git a/README.md b/README.md index 6e6c8de81..f5f16a725 100644 --- a/README.md +++ b/README.md @@ -42,26 +42,18 @@ pip install -v --no-cache-dir --global-option="--cuda_ext" . ```python import colossalai -from colossalai.engine import Engine from colossalai.trainer import Trainer from colossalai.core import global_context as gpc -model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize() -engine = Engine( - model=model, - criterion=criterion, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - schedule=schedule -) +engine, train_dataloader, test_dataloader = colossalai.initialize() trainer = Trainer(engine=engine, - hooks_cfg=gpc.config.hooks, verbose=True) trainer.fit( train_dataloader=train_dataloader, test_dataloader=test_dataloader, - max_epochs=gpc.config.num_epochs, + epochs=gpc.config.num_epochs, + hooks_cfg=gpc.config.hooks, display_progress=True, test_interval=5 ) diff --git a/colossalai/builder/__init__.py b/colossalai/builder/__init__.py index 17d643285..2ae194132 100644 --- a/colossalai/builder/__init__.py +++ b/colossalai/builder/__init__.py @@ -1,2 +1,10 @@ -from .builder import * +from .builder import (build_schedule, build_lr_scheduler, build_model, build_optimizer, build_optimizer_wrapper, + build_layer, build_loss, build_hooks, build_dataset, build_transform, build_data_sampler, + build_gradient_handler) from .pipeline import ModelInitializer + +__all__ = [ + 'build_schedule', 'build_lr_scheduler', 'build_model', 'build_optimizer', 'build_optimizer_wrapper', + 'build_layer', 'build_loss', 'build_hooks', 'build_dataset', 'build_transform', 'build_data_sampler', + 'build_gradient_handler', 'ModelInitializer' +] diff --git a/colossalai/builder/builder.py b/colossalai/builder/builder.py index f88dc1cbf..c32ad3b39 100644 --- a/colossalai/builder/builder.py +++ b/colossalai/builder/builder.py @@ -181,18 +181,6 @@ def build_transform(config): return build_from_registry(config, TRANSFORMS) -def build_pipe_alloc_policy(config): - """Returns a pipeline allocation policy object constructed from `config`. - - :param config: A python dict or a :class:`colossalai.context.Config` object - containing information used in the construction of the return object - :type config: dict or :class:`colossalai.context.Config` - :return: A pipeline allocation policy object - :rtype: - """ - return build_from_registry(config, PIPE_ALLOC_POLICY) - - def build_data_sampler(config, dataset): """Returns a data sampler object of :class:`colossalai.nn.data.sampler.BaseSampler` constructed from `config`. @@ -235,7 +223,7 @@ def build_optimizer_wrapper(config, optimizer, model=None): return OPTIMIZER_WRAPPERS.get_module(mod_type)(optimizer, **config_) -def build_lr_scheduler(config, optimizer, total_steps, num_steps_per_epoch): +def build_lr_scheduler(config, optimizer): """Returns a learning rate scheduler object of :class:`torch.optim.lr_scheduler` constructed from `config`, `optimizer`, `total_steps` and `num_steps_per_epoch`. @@ -254,9 +242,16 @@ def build_lr_scheduler(config, optimizer, total_steps, num_steps_per_epoch): """ config_ = config.copy() mod_type = config_.pop('type') - # warmup epochs will overwrite warmup steps - if 'warmup_epochs' in config_: - warmup_epochs = config_.pop('warmup_epochs') - config_['warmup_steps'] = int(num_steps_per_epoch * warmup_epochs) - return LR_SCHEDULERS.get_module(mod_type)(optimizer, total_steps, num_steps_per_epoch=num_steps_per_epoch, - **config_) + return LR_SCHEDULERS.get_module(mod_type)(optimizer, **config_) + + +def build_schedule(config): + """Returns a schedule of :class:`colossalai.engine.schedule.BaseSchedule`. + + :param config: A python dict or a :class:`colossalai.context.Config` object + containing information used in the construction of the return object + :type config: dict or :class:`colossalai.context.Config` + :return: An object of :class:`colossalai.engine.schedule.BaseSchedule` + :rtype: :class:`colossalai.engine.schedule.BaseSchedule` + """ + return build_from_registry(config, SCHEDULE) diff --git a/colossalai/engine/__init__.py b/colossalai/engine/__init__.py index c00be7df6..7e5592236 100644 --- a/colossalai/engine/__init__.py +++ b/colossalai/engine/__init__.py @@ -1,7 +1,7 @@ -from .amp_type import AMP_TYPE from ._base_engine import Engine from .gradient_handler import * from .schedule import * +from .amp import * __all__ = ['Engine'] diff --git a/colossalai/engine/_base_engine.py b/colossalai/engine/_base_engine.py index 843ef1d4f..a99aa91e7 100644 --- a/colossalai/engine/_base_engine.py +++ b/colossalai/engine/_base_engine.py @@ -1,7 +1,9 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from typing import Optional +from torch.nn import Module +from torch.nn.modules.loss import _Loss +from torch.optim import Optimizer from colossalai.builder import build_gradient_handler from colossalai.context import ParallelMode @@ -9,89 +11,103 @@ from colossalai.core import global_context as gpc from colossalai.logging import get_global_dist_logger from colossalai.nn import (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3) -from torch.nn import Module -from torch.nn.modules.loss import _Loss -from torch.optim import Optimizer -from torch.optim.lr_scheduler import _LRScheduler -from torch.utils.data import DataLoader - -from .schedule import BaseSchedule, NoPipelineSchedule +from .schedule import BaseSchedule class Engine: """Basic engine class for training and evaluation. It runs a specific process method :meth:`step` which is based on the given :attr:`schedule` over each batch of a dataset. + It controls a iteration in training. - :param train_dataloader: Dataloader in training - :param test_dataloader: Dataloader in evaluation :param model: The neural network model - :param criterion: Criterion for calculating loss :param optimizer: Optimizer for updating the parameters - :param lr_scheduler: Learning rate scheduler ajusting learning rate during the training or evaluation - :param schedule: Running schedule in :meth:`step` - :type train_dataloader: DataLoader, optional - :type test_dataloader: DataLoader, optional + :param step_schedule: Running schedule in :meth:`step` + :param gradient_accumulation: Steps of gradient accumulation + :param gradient_clipping: The norm of gradient clipping :type model: Module - :type criterion: _Loss, optional - :type optimizer: Optimizer, optional - :type lr_scheduler: _LRScheduler, optional - :type schedule: BaseSchedule, optional + :type optimizer: Optimizer + :type step_schedule: BaseSchedule, optional + :type gradient_accumulation: int, optional + :type gradient_clipping: float, optional """ + def __init__(self, - train_dataloader: Optional[DataLoader] = None, - test_dataloader: Optional[DataLoader] = None, - model: Module = None, - criterion: _Loss = None, - optimizer: Optimizer = None, - lr_scheduler: Optional[_LRScheduler] = None, - schedule: BaseSchedule = None): - self.train_dataloader = train_dataloader - self.test_dataloader = test_dataloader - assert model is not None, "Engine requires a model" - self.model = model - self.criterion = criterion - self.optimizer = optimizer - self.lr_scheduler = lr_scheduler - self.schedule = schedule if schedule is not None \ - else NoPipelineSchedule() + model: Module, + optimizer: Optimizer, + criterion: _Loss, + step_schedule: BaseSchedule, + gradient_handlers: list = None, + gradient_accumulation: int = 1, + gradient_clipping: float = 0.0, + ): + self._model = model + self._optimizer = optimizer + self._criterion = criterion + self._schedule = step_schedule + + # schedule initialize + self._schedule.initialize(model, optimizer) + + # state + self.training = True # default + + # gradient accumulation + assert gradient_accumulation > 0, 'gradient accumulation size must be larger than 0' + self._grad_accum_size = gradient_accumulation + self._grad_clip = gradient_clipping self._logger = get_global_dist_logger() # build gradient handler self._gradient_handlers = [] - gradient_handler_cfg = [] - if hasattr(gpc.config, 'gradient_handler'): - assert isinstance(gpc.config.gradient_handler, list), \ + if gradient_handlers is not None: + assert isinstance(gradient_handlers, list), \ f'argument gradient_handler_cfg expected type list, ' \ - f'but got type {type(gpc.config.gradient_handler)}' - gradient_handler_cfg = gpc.config.gradient_handler - elif isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2, - ZeroRedundancyOptimizer_Level_3)): - gradient_handler_cfg = [dict(type='ZeROGradientHandler')] + f'but got type {type(gradient_handlers)}' + elif isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2, + ZeroRedundancyOptimizer_Level_3)): + gradient_handlers = [dict(type='ZeROGradientHandler')] self._logger.info( "Training with zero is detected, ZeROGradientHandler is automatically " "added even though not specified in the configuration", ranks=[0]) elif gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size( ParallelMode.DATA) > 1: - gradient_handler_cfg = [dict(type='DataParallelGradientHandler')] + gradient_handlers = [dict(type='DataParallelGradientHandler')] self._logger.info( "Data parallel training is detected, DataParallelGradientHandler is automatically " "added even though not specified in the configuration", ranks=[0]) - if len(gradient_handler_cfg) == 0: + + if gradient_handlers is None: self._logger.warning( "No gradient handler is set up, please make sure you do not need " "to all-reduce the gradients after a training step.", ranks=[0]) - for cfg in gradient_handler_cfg: - handler = build_gradient_handler(cfg, self.model, self.optimizer) - self._gradient_handlers.append(handler) + else: + for cfg in gradient_handlers: + handler = build_gradient_handler(cfg, model, optimizer) + self._gradient_handlers.append(handler) - self.schedule.initialize(self.train_dataloader, self.model, - self.criterion, self.optimizer, - self.lr_scheduler) - self.forward_only = False + @property + def model(self): + return self._model + + @property + def optimizer(self): + return self._optimizer + + @property + def criterion(self): + return self._criterion + + @property + def schedule(self): + return self._schedule + + @property + def gradient_accumulation(self): + return self._grad_accum_size def handle_gradient(self): """Handles all-reduce operations of gradients across different parallel groups. @@ -99,72 +115,62 @@ class Engine: for handler in self._gradient_handlers: handler.handle_gradient() - def set_dataloader(self, data: DataLoader, train: bool = True): - """Sets dataloader in training or evaluation. - - :param data: Dataloader to be set - :param train: Set training dataloader if True, otherwise evaluation dataloader - :type data: DataLoader - :type train: bool - """ - if train: - self.train_dataloader = data - else: - self.test_dataloader = data - - def get_model(self): - """Returns the neural network model in the engine. - """ - return self.model - def get_optimizer(self): - """Returns optimizier in the engine. - """ - return self.optimizer - - def get_lr_scheduler(self): - """Returns the learning rate scheduler in the engine. - """ - return self.lr_scheduler - def train(self): """Sets the model to training mode. """ - self.forward_only = False - self.schedule.train(dataloader=self.train_dataloader, mode=True) + self.training = True + self._model.train() def eval(self): """Sets the model to evaluation mode. """ - self.forward_only = True - self.schedule.train(dataloader=self.test_dataloader, mode=False) + self.training = False + self._model.eval() - def is_train(self): - """Returns True if it is in training, otherwise False. - """ - return not self.forward_only - - def get_lr(self): - """Gets current learning rate. - """ - return self.schedule.get_lr() - - def step(self, return_loss=True): + def step(self, + data_iter, + is_last_iteration: bool = False, + return_loss=True): """A running step based on the schedule. Usually, it runs a training or evaluation over a batch of dataset. + :param data_iter: Data iterator of the dataset + :param is_last_iteration: If True, this iteration is the last iteration in the epoch :param return_loss: loss will be returned if True - :type return_loss: bool + :type data_iter: Iterator + :type is_last_iteration: bool, optional + :type return_loss: bool, optional :return: (output, lablel, loss) """ - self.schedule.zero_grad(forward_only=self.forward_only) + if self.training: + self._optimizer.zero_grad() - output, label, loss = self.schedule.forward_backward_step( - forward_only=self.forward_only, return_loss=return_loss) + # differentiate training and eval with grad accum + if self.training: + for i in range(self._grad_accum_size): + output, label, loss = self._schedule.forward_backward_step( + data_iter, self._model, self._criterion, self._optimizer, + forward_only=False, + grad_accum_size=self._grad_accum_size, + return_loss=return_loss) - if not self.forward_only: - # all reduce gradients - self.handle_gradient() + if i == self._grad_accum_size - 1: + # all reduce gradients + self.handle_gradient() + self._schedule.optimizer_step(self._model, self._optimizer, self._grad_clip) + else: + output, label, loss = self._schedule.forward_backward_step( + data_iter, self._model, self._criterion, self._optimizer, + forward_only=True, + grad_accum_size=1, + return_loss=return_loss) - self.schedule.step() + # consume the remaining dataset left out due to gradient accumulation + if is_last_iteration: + while True: + try: + _ = next(data_iter) + except StopIteration: + break return output, label, loss diff --git a/colossalai/engine/amp/__init__.py b/colossalai/engine/amp/__init__.py new file mode 100644 index 000000000..927d5cf09 --- /dev/null +++ b/colossalai/engine/amp/__init__.py @@ -0,0 +1,2 @@ +from .grad_scaler import GradScaler +from .amp_type import AMP_TYPE diff --git a/colossalai/engine/amp_type.py b/colossalai/engine/amp/amp_type.py similarity index 100% rename from colossalai/engine/amp_type.py rename to colossalai/engine/amp/amp_type.py diff --git a/colossalai/engine/amp/grad_scaler.py b/colossalai/engine/amp/grad_scaler.py new file mode 100644 index 000000000..7859d132d --- /dev/null +++ b/colossalai/engine/amp/grad_scaler.py @@ -0,0 +1,577 @@ +# modified from https://github.com/pytorch/pytorch/blob/master/torch/cuda/amp/grad_scaler.p +import torch +from collections import defaultdict, abc +import warnings +from enum import Enum +from typing import Any, Dict, List, Optional, Tuple +from colossalai.context import ParallelMode +import torch.distributed as dist +from colossalai.core import global_context as gpc + + +class _MultiDeviceReplicator(object): + """ + Lazily serves copies of a tensor to requested devices. Copies are cached per-device. + """ + + def __init__(self, master_tensor: torch.Tensor) -> None: + assert master_tensor.is_cuda or master_tensor.device.type == 'xla' + self.master = master_tensor + self._per_device_tensors: Dict[torch.device, torch.Tensor] = {} + + def get(self, device) -> torch.Tensor: + retval = self._per_device_tensors.get(device, None) + if retval is None: + retval = self.master.to( + device=device, non_blocking=True, copy=True) + self._per_device_tensors[device] = retval + return retval + + +# Defines default_factory for GradScaler's _per_optimizer_states defaultdict, +# as well as associated "enum" values. Prefers defining these at top level because +# - Lambdas can't be pickled, so we don't want to supply a lambda as the factory. +# - Defining READY, UNSCALED, STEPPED and _refresh_per_optimizer_state within GradScaler +# causes a circular reference, which we'd rather avoid. +class OptState(Enum): + READY = 0 + UNSCALED = 1 + STEPPED = 2 + + +def _refresh_per_optimizer_state(): + return {"stage": OptState.READY, "found_inf_per_device": {}} + + +class GradScaler(object): + _scale: Optional[torch.Tensor] + _grows_tracker: Optional[torch.Tensor] + _per_optimizer_states: Dict[int, Dict[str, Any]] + """ + An instance ``scaler`` of :class:`GradScaler` helps perform the steps of gradient scaling + conveniently. + + * ``scaler.scale(loss)`` multiplies a given loss by ``scaler``'s current scale factor. + * ``scaler.step(optimizer)`` safely unscales gradients and calls ``optimizer.step()``. + * ``scaler.update()`` updates ``scaler``'s scale factor. + + Example:: + + # Creates a GradScaler once at the beginning of training. + scaler = GradScaler() + + for epoch in epochs: + for input, target in data: + optimizer.zero_grad() + output = model(input) + loss = loss_fn(output, target) + + # Scales loss. Calls backward() on scaled loss to create scaled gradients. + scaler.scale(loss).backward() + + # scaler.step() first unscales gradients of the optimizer's params. + # If gradients don't contain infs/NaNs, optimizer.step() is then called, + # otherwise, optimizer.step() is skipped. + scaler.step(optimizer) + + # Updates the scale for next iteration. + scaler.update() + + See the :ref:`Automatic Mixed Precision examples` for usage + (along with autocasting) in more complex cases like gradient clipping, gradient accumulation, gradient penalty, + and multiple losses/optimizers. + + ``scaler`` dynamically estimates the scale factor each iteration. To minimize gradient underflow, + a large scale factor should be used. However, ``float16`` values can "overflow" (become inf or NaN) if + the scale factor is too large. Therefore, the optimal scale factor is the largest factor that can be used + without incurring inf or NaN gradient values. + ``scaler`` approximates the optimal scale factor over time by checking the gradients for infs and NaNs during every + ``scaler.step(optimizer)`` (or optional separate ``scaler.unscale_(optimizer)``, see :meth:`unscale_`). + + * If infs/NaNs are found, ``scaler.step(optimizer)`` skips the underlying ``optimizer.step()`` (so the params + themselves remain uncorrupted) and ``update()`` multiplies the scale by ``backoff_factor``. + + * If no infs/NaNs are found, ``scaler.step(optimizer)`` runs the underlying ``optimizer.step()`` as usual. + If ``growth_interval`` unskipped iterations occur consecutively, ``update()`` multiplies the scale by + ``growth_factor``. + + The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its + value calibrates. ``scaler.step`` will skip the underlying ``optimizer.step()`` for these + iterations. After that, step skipping should occur rarely (once every few hundred or thousand iterations). + + Args: + init_scale (float, optional, default=2.**16): Initial scale factor. + growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during + :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations. + backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during + :meth:`update` if inf/NaN gradients occur in an iteration. + growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients + that must occur for the scale to be multiplied by ``growth_factor``. + enabled (bool, optional, default=True): If ``False``, disables gradient scaling. :meth:`step` simply + invokes the underlying ``optimizer.step()``, and other methods become no-ops. + """ + + def __init__(self, + init_scale=2.**16, + growth_factor=2.0, + backoff_factor=0.5, + growth_interval=2000, + enabled=True): + if enabled and not torch.cuda.is_available(): + warnings.warn( + "torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.") + self._enabled = False + else: + self._enabled = enabled + + if self._enabled: + assert growth_factor > 1.0, "The growth factor must be > 1.0." + assert backoff_factor < 1.0, "The backoff factor must be < 1.0." + + self._init_scale = init_scale + # self._scale will be lazily initialized during the first call to scale() + self._scale = None + self._growth_factor = growth_factor + self._backoff_factor = backoff_factor + self._growth_interval = growth_interval + self._init_growth_tracker = 0 + # self._growth_tracker will be lazily initialized during the first call to scale() + self._growth_tracker = None + self._per_optimizer_states = defaultdict( + _refresh_per_optimizer_state) + + def _check_scale_growth_tracker(self, funcname) -> Tuple[torch.Tensor, torch.Tensor]: + fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration." + assert self._scale is not None, "Attempted {} but _scale is None. ".format( + funcname) + fix + assert self._growth_tracker is not None, "Attempted {} but _growth_tracker is None. ".format( + funcname) + fix + return (self._scale, self._growth_tracker) + + def _lazy_init_scale_growth_tracker(self, dev): + assert self._growth_tracker is None, "_growth_tracker initialized before _scale" + self._scale = torch.full( + (1,), self._init_scale, dtype=torch.float32, device=dev) + self._growth_tracker = torch.full( + (1,), self._init_growth_tracker, dtype=torch.int32, device=dev) + + def scale(self, outputs): + """ + Multiplies ('scales') a tensor or list of tensors by the scale factor. + + Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned + unmodified. + + Args: + outputs (Tensor or iterable of Tensors): Outputs to scale. + """ + if not self._enabled: + return outputs + + # Short-circuit for the common case. + if isinstance(outputs, torch.Tensor): + assert outputs.is_cuda or outputs.device.type == 'xla' + if self._scale is None: + self._lazy_init_scale_growth_tracker(outputs.device) + assert self._scale is not None + return outputs * self._scale.to(device=outputs.device, non_blocking=True) + + # Invoke the more complex machinery only if we're treating multiple outputs. + # holds a reference that can be overwritten by apply_scale + stash: List[_MultiDeviceReplicator] = [] + + def apply_scale(val): + if isinstance(val, torch.Tensor): + assert val.is_cuda or val.device.type == 'xla' + if len(stash) == 0: + if self._scale is None: + self._lazy_init_scale_growth_tracker(val.device) + assert self._scale is not None + stash.append(_MultiDeviceReplicator(self._scale)) + return val * stash[0].get(val.device) + elif isinstance(val, abc.Iterable): + iterable = map(apply_scale, val) + if isinstance(val, list) or isinstance(val, tuple): + return type(val)(iterable) + else: + return iterable + else: + raise ValueError( + "outputs must be a Tensor or an iterable of Tensors") + + return apply_scale(outputs) + + def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): + per_device_inv_scale = _MultiDeviceReplicator(inv_scale) + per_device_found_inf = _MultiDeviceReplicator(found_inf) + + # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype. + # There could be hundreds of grads, so we'd like to iterate through them just once. + # However, we don't know their devices or dtypes in advance. + + # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict + # Google says mypy struggles with defaultdicts type annotations. + per_device_and_dtype_grads = defaultdict( + lambda: defaultdict(list)) # type: ignore[var-annotated] + with torch.no_grad(): + for group in optimizer.param_groups: + for param in group["params"]: + if param.grad is None: + continue + if (not allow_fp16) and param.grad.dtype == torch.float16: + raise ValueError( + "Attempting to unscale FP16 gradients.") + if param.grad.is_sparse: + # is_coalesced() == False means the sparse grad has values with duplicate indices. + # coalesce() deduplicates indices and adds all values that have the same index. + # For scaled fp16 values, there's a good chance coalescing will cause overflow, + # so we should check the coalesced _values(). + if param.grad.dtype is torch.float16: + param.grad = param.grad.coalesce() + to_unscale = param.grad._values() + else: + to_unscale = param.grad + + # TODO: is there a way to split by device and dtype without appending in the inner loop? + per_device_and_dtype_grads[to_unscale.device][to_unscale.dtype].append( + to_unscale) + + for device, per_dtype_grads in per_device_and_dtype_grads.items(): + for grads in per_dtype_grads.values(): + torch._amp_foreach_non_finite_check_and_unscale_(grads, + per_device_found_inf.get( + device), + per_device_inv_scale.get(device)) + # For tensor parallel paramters it should be all-reduced over tensor parallel process group + if gpc.is_initialized(ParallelMode.TENSOR) and gpc.get_world_size(ParallelMode.TENSOR) > 1: + for tensor in per_device_found_inf._per_device_tensors.values(): + dist.all_reduce(tensor, op=dist.ReduceOp.MAX, + group=gpc.get_group(ParallelMode.TENSOR)) + return per_device_found_inf._per_device_tensors + + def unscale_(self, optimizer): + """ + Divides ("unscales") the optimizer's gradient tensors by the scale factor. + + :meth:`unscale_` is optional, serving cases where you need to + :ref:`modify or inspect gradients` + between the backward pass(es) and :meth:`step`. + If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`. + + Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients:: + + ... + scaler.scale(loss).backward() + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) + scaler.step(optimizer) + scaler.update() + + Args: + optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled. + + .. note:: + :meth:`unscale_` does not incur a CPU-GPU sync. + + .. warning:: + :meth:`unscale_` should only be called once per optimizer per :meth:`step` call, + and only after all gradients for that optimizer's assigned parameters have been accumulated. + Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError. + + .. warning:: + :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute. + """ + if not self._enabled: + return + + self._check_scale_growth_tracker("unscale_") + + optimizer_state = self._per_optimizer_states[id(optimizer)] + + if optimizer_state["stage"] is OptState.UNSCALED: + raise RuntimeError( + "unscale_() has already been called on this optimizer since the last update().") + elif optimizer_state["stage"] is OptState.STEPPED: + raise RuntimeError("unscale_() is being called after step().") + + # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64. + assert self._scale is not None + inv_scale = self._scale.double().reciprocal().float() + found_inf = torch.full( + (1,), 0.0, dtype=torch.float32, device=self._scale.device) + + optimizer_state["found_inf_per_device"] = self._unscale_grads_( + optimizer, inv_scale, found_inf, False) + optimizer_state["stage"] = OptState.UNSCALED + + def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs): + retval = None + if not sum(v.item() for v in optimizer_state["found_inf_per_device"].values()): + retval = optimizer.step(*args, **kwargs) + return retval + + def step(self, optimizer, *args, **kwargs): + """ + :meth:`step` carries out the following two operations: + + 1. Internally invokes ``unscale_(optimizer)`` (unless :meth:`unscale_` was explicitly called for ``optimizer`` + earlier in the iteration). As part of the :meth:`unscale_`, gradients are checked for infs/NaNs. + 2. If no inf/NaN gradients are found, invokes ``optimizer.step()`` using the unscaled + gradients. Otherwise, ``optimizer.step()`` is skipped to avoid corrupting the params. + + ``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``. + + Returns the return value of ``optimizer.step(*args, **kwargs)``. + + Args: + optimizer (torch.optim.Optimizer): Optimizer that applies the gradients. + args: Any arguments. + kwargs: Any keyword arguments. + + .. warning:: + Closure use is not currently supported. + """ + if (not self._enabled): + return optimizer.step(*args, **kwargs) + + if "closure" in kwargs: + raise RuntimeError( + "Closure use is not currently supported if GradScaler is enabled.") + + self._check_scale_growth_tracker("step") + + optimizer_state = self._per_optimizer_states[id(optimizer)] + + if optimizer_state["stage"] is OptState.STEPPED: + raise RuntimeError( + "step() has already been called since the last update().") + + retval = None + + if (hasattr(optimizer, "_step_supports_amp_scaling") and optimizer._step_supports_amp_scaling): + # This optimizer has customized scale-handling logic, so we can call optimizer.step() directly. + # The contract with custom optimizers is that their step() should accept an additional, + # optional grad_scaler kwarg. We append self to the kwargs so the custom optimizer has full information: + # it can query its own state, invoke unscale_ on itself, etc + retval = optimizer.step(*args, **dict(kwargs, grad_scaler=self)) + optimizer_state["stage"] = OptState.STEPPED + return retval + + if optimizer_state["stage"] is OptState.READY: + self.unscale_(optimizer) + + assert len(optimizer_state["found_inf_per_device"] + ) > 0, "No inf checks were recorded for this optimizer." + + retval = self._maybe_opt_step( + optimizer, optimizer_state, *args, **kwargs) + + optimizer_state["stage"] = OptState.STEPPED + + return retval + + def update(self, new_scale=None): + """ + Updates the scale factor. + + If any optimizer steps were skipped the scale is multiplied by ``backoff_factor`` + to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively, + the scale is multiplied by ``growth_factor`` to increase it. + + Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not + used directly, it's used to fill GradScaler's internal scale tensor. So if + ``new_scale`` was a tensor, later in-place changes to that tensor will not further + affect the scale GradScaler uses internally.) + + Args: + new_scale (float or :class:`torch.cuda.FloatTensor`, optional, default=None): New scale factor. + + .. warning:: + :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has + been invoked for all optimizers used this iteration. + """ + if not self._enabled: + return + + _scale, _growth_tracker = self._check_scale_growth_tracker("update") + + if new_scale is not None: + # Accept a new user-defined scale. + if isinstance(new_scale, float): + self._scale.fill_(new_scale) # type: ignore[union-attr] + else: + reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False." + # type: ignore[attr-defined] + assert isinstance(new_scale, torch.cuda.FloatTensor), reason + assert new_scale.numel() == 1, reason + assert new_scale.requires_grad is False, reason + self._scale.copy_(new_scale) # type: ignore[union-attr] + else: + # Consume shared inf/nan data collected from optimizers to update the scale. + # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous. + found_infs = [found_inf.to(device=_scale.device, non_blocking=True) + for state in self._per_optimizer_states.values() + for found_inf in state["found_inf_per_device"].values()] + + assert len( + found_infs) > 0, "No inf checks were recorded prior to update." + + found_inf_combined = found_infs[0] + if len(found_infs) > 1: + for i in range(1, len(found_infs)): + found_inf_combined += found_infs[i] + + torch._amp_update_scale_(_scale, + _growth_tracker, + found_inf_combined, + self._growth_factor, + self._backoff_factor, + self._growth_interval) + + # To prepare for next iteration, clear the data collected from optimizers this iteration. + self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) + + def _get_scale_async(self): + return self._scale + + def get_scale(self): + """ + Returns a Python float containing the current scale, or 1.0 if scaling is disabled. + + .. warning:: + :meth:`get_scale` incurs a CPU-GPU sync. + """ + if self._enabled: + return self._init_scale if self._scale is None else self._get_scale_async().item() + else: + return 1.0 + + def get_growth_factor(self): + r""" + Returns a Python float containing the scale growth factor. + """ + return self._growth_factor + + def set_growth_factor(self, new_factor): + r""" + Args: + new_scale (float): Value to use as the new scale growth factor. + """ + self._growth_factor = new_factor + + def get_backoff_factor(self): + r""" + Returns a Python float containing the scale backoff factor. + """ + return self._backoff_factor + + def set_backoff_factor(self, new_factor): + r""" + Args: + new_scale (float): Value to use as the new scale backoff factor. + """ + self._backoff_factor = new_factor + + def get_growth_interval(self): + r""" + Returns a Python int containing the growth interval. + """ + return self._growth_interval + + def set_growth_interval(self, new_interval): + r""" + Args: + new_interval (int): Value to use as the new growth interval. + """ + self._growth_interval = new_interval + + def _get_growth_tracker(self): + if self._enabled: + return self._init_growth_tracker if self._growth_tracker is None else self._growth_tracker.item() + else: + return 0 + + def is_enabled(self): + r""" + Returns a bool indicating whether this instance is enabled. + """ + return self._enabled + + def state_dict(self): + r""" + Returns the state of the scaler as a :class:`dict`. It contains five entries: + + * ``"scale"`` - a Python float containing the current scale + * ``"growth_factor"`` - a Python float containing the current growth factor + * ``"backoff_factor"`` - a Python float containing the current backoff factor + * ``"growth_interval"`` - a Python int containing the current growth interval + * ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps. + + If this instance is not enabled, returns an empty dict. + + .. note:: + If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict` + should be called after :meth:`update`. + """ + return {"scale": self.get_scale(), + "growth_factor": self._growth_factor, + "backoff_factor": self._backoff_factor, + "growth_interval": self._growth_interval, + "_growth_tracker": self._get_growth_tracker()} if self._enabled else {} + + def load_state_dict(self, state_dict): + r""" + Loads the scaler state. If this instance is disabled, :meth:`load_state_dict` is a no-op. + + Args: + state_dict(dict): scaler state. Should be an object returned from a call to :meth:`state_dict`. + """ + if not self._enabled: + return + + if len(state_dict) == 0: + raise RuntimeError("The source state dict is empty, possibly because it was saved " + "from a disabled instance of GradScaler.") + + self._init_scale = state_dict["scale"] + if self._scale is not None: + self._scale.fill_(state_dict["scale"]) + self._growth_factor = state_dict["growth_factor"] + self._backoff_factor = state_dict["backoff_factor"] + self._growth_interval = state_dict["growth_interval"] + self._init_growth_tracker = state_dict["_growth_tracker"] + if self._growth_tracker is not None: + self._growth_tracker.fill_(state_dict["_growth_tracker"]) + + def __getstate__(self): + state = self.__dict__.copy() + if self._enabled: + assert len(self._per_optimizer_states) == 0, "A GradScaler instance may only be pickled at the beginning "\ + "of an iteration, or at the end after scaler.update()." + # Pickling _scale and _growth_tracker Tensors directly triggers + # "warnings.warn("pickle support for Storage will be removed in 1.5..." + # so instead, we set the unpickled instance up to reinitialize them lazily. + state['_init_scale'] = self.get_scale() + state['_init_growth_tracker'] = self._get_growth_tracker() + state['_scale'] = None + state['_growth_tracker'] = None + return state + + def __setstate__(self, state): + self.__dict__.update(state) + + def _check_inf_per_device(self, optimizer): + _scale, _ = self._check_scale_growth_tracker("_check_inf_per_device") + + dummy_inv_scale = torch.full( + (1,), 1.0, dtype=torch.float32, device=_scale.device) + found_inf = torch.full( + (1,), 0.0, dtype=torch.float32, device=_scale.device) + + self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \ + self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True) + + return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] + + def _found_inf_per_device(self, optimizer): + return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] diff --git a/colossalai/engine/schedule/_base_schedule.py b/colossalai/engine/schedule/_base_schedule.py index c64031c09..0583ccbf3 100644 --- a/colossalai/engine/schedule/_base_schedule.py +++ b/colossalai/engine/schedule/_base_schedule.py @@ -5,125 +5,85 @@ from abc import ABC, abstractmethod import torch +from colossalai.core import global_context as gpc from colossalai.logging import get_global_dist_logger from colossalai.utils import get_current_device class BaseSchedule(ABC): """A basic helper class to control the process of training or evaluation. + It mainly composes of forward_backward_step for gradient backward and + optimizer_step for parameters update. + For the convenience to enable FP16, we aggreate all codes that contain the + control of FP16 in class schedule. """ + def __init__(self): - self.initialized = False self.logger = get_global_dist_logger() - @property - @abstractmethod - def num_steps(self): - """The number of batches in training or evaluation. - """ - pass - - def initialize(self, - dataloader=None, - model=None, - criterion=None, - optimizer=None, - lr_scheduler=None): - """Initializes the schedule and set parameters before running. - - :param dataloader: DataLoader in training or evaluation - :param model: The neural network model - :param criterion: Criterion for calculating loss - :param optimizer: Optimizer for updating the parameters - :param lr_scheduler: Learning rate scheduler in the process - """ - self.dataloader = dataloader - assert model is not None, "Schedule requires a model" - self.model = model - assert criterion is not None, "Schedule requires a criterion" - self.criterion = criterion - assert optimizer is not None, "Schedule requires an optimizer" - self.optimizer = optimizer - self.lr_scheduler = lr_scheduler - self.initialized = True - - def check_initialized(self): - """Checks whether the schedule is initialized. - """ - assert self.initialized, \ - 'Schedule is not initialized. Call schedule.initialize(...) before using it.' - - def load_batch(self): - """Loads a batch of dataset. It returns the data and labels which are - already in the same GPU as where the model's. - - :return: (data, label) - :rtype: (Tensor, Tensor) - """ - self.check_initialized() - if self.data_iter is None: - raise RuntimeError('Dataloader is not defined.') - data, label = next(self.data_iter) - return self._move_to_device(data), self._move_to_device(label) + @staticmethod + def _move_tensor(element): + if torch.is_tensor(element): + if not element.is_cuda: + return element.to(get_current_device()).detach() + return element def _move_to_device(self, data): - if isinstance(data, ( - tuple, - list, - )): - data = tuple([ - d.to(get_current_device()).detach() for d in data - if torch.is_tensor(d) - ]) + if isinstance(data, (tuple, list)): + data = tuple([self._move_tensor(d) for d in data]) elif torch.is_tensor(data): data = data.to(get_current_device()).detach() return data - def train(self, dataloader=None, mode=True): - """Sets the dataloader to be used and turn the model to - training or evaluation mode. + def load_batch(self, data_iter): + """Loads a batch from data iterator. It returns the data and labels which are + already in the same GPU as where the model's. - :param dataloader: Dataloader to be used - :param mode: If True, the model will set as training mode. Otherwise, evaluation mode. + :return: (data, label) + :rtype: (Tensor, Tensor) """ - self.check_initialized() - if mode: - self.model.train() - else: - self.model.eval() - if dataloader is not None: - self.dataloader = dataloader - self.data_iter = iter(dataloader) + if data_iter is None: + raise RuntimeError('Dataloader is not defined.') + data, label = next(data_iter) + return self._move_to_device(data), self._move_to_device(label) - def zero_grad(self, forward_only=False): - """Cleans gradients with the optimizer. - """ - if not forward_only: - self.check_initialized() - self.optimizer.zero_grad() + def initialize(self, model, optimizer): + """Initializes the model and the optimizer before training. + This is often used in FP16 training. - def get_lr(self): - """Returns the current learning rate. + :param model: The neural network model + :param optimizer: Optimizer for updating the parameters """ - if self.lr_scheduler is not None: - return self.lr_scheduler.get_lr()[0] - else: - return self.optimizer.param_groups[0]['lr'] - - def step(self): - """Updates the parameters and learning rate with the optimizer. - """ - self.check_initialized() - self.optimizer.step() - # update lr scheduler - if self.lr_scheduler is not None: - self.lr_scheduler.step() + return model, optimizer @abstractmethod - def forward_backward_step(self, forward_only=False, return_loss=True): + def forward_backward_step(self, + data_iter, + model, + criterion, + optimizer=None, + forward_only=False, + grad_accum_size: int = 1, + return_loss=True): """The process function over a batch of dataset for training or evaluation. - :param forward_only: If True, the process won't include backward. - :param return_loss: If False, the loss won't be returned. + :param data_iter: Data iterator of the dataset + :param model: Model used in training or evaluation + :param optimizer: Optimizer used in training or evaluation + :param criterion: Loss function + :param forward_only: If True, the process won't include backward + :param grad_accum_size: Steps of gradient accumulation + :param return_loss: If False, the loss won't be returned + """ + pass + + @abstractmethod + def optimizer_step(self, model, optimizer, grad_clipping: float = 0.0): + """Updates the parameters with the optimizer. + + :param model: The neural network model + :param optimizer: Optimizer for updating the parameters + :param grad_clipping: The norm of gradient clipping + :type grad_clipping: float, optional """ pass diff --git a/colossalai/engine/schedule/_no_pipeline.py b/colossalai/engine/schedule/_no_pipeline.py index 3ab1fa2d3..4f38e6cda 100644 --- a/colossalai/engine/schedule/_no_pipeline.py +++ b/colossalai/engine/schedule/_no_pipeline.py @@ -4,19 +4,24 @@ try: import apex.amp as apex_amp except: - print('apex is required for mixed precision training') + pass + try: import torch.cuda.amp as torch_amp except: - print('PyTorch amp is not supported with the current PyTorch version') + pass + +from typing import Iterable + +import torch.nn as nn +from torch.optim import Optimizer -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.engine.amp_type import AMP_TYPE from colossalai.nn import (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3) -from ._utils import convert_to_fp16 +from colossalai.nn.optimizer._utils import clip_grad_norm_fp32 from ._base_schedule import BaseSchedule +from ._utils import convert_to_fp16, convert_to_fp32 +from ..amp import AMP_TYPE, GradScaler class NoPipelineSchedule(BaseSchedule): @@ -30,6 +35,7 @@ class NoPipelineSchedule(BaseSchedule): :type amp_type: AMP_TYPE :type amp_config: dict """ + def __init__( self, amp_type: AMP_TYPE = None, @@ -41,12 +47,6 @@ class NoPipelineSchedule(BaseSchedule): assert amp_type is None or isinstance(amp_type, AMP_TYPE), \ 'unrecognised value for argument fp16, it can only be None, torch or apex' - # LSG: check compatibility - # LSG: torch.cuda.amp and apex.amp cannot be used for tensor parallel - if gpc.is_initialized(ParallelMode.TENSOR) and gpc.get_world_size( - ParallelMode.TENSOR) > 1: - assert amp_type != AMP_TYPE.TORCH and amp_type != AMP_TYPE.APEX, \ - 'You can only AMP_TYPE.PARALLEL for tensor parallel training' self.use_zero_level_2_3 = False if amp_type is not None: @@ -79,107 +79,110 @@ class NoPipelineSchedule(BaseSchedule): self.fp16 = False self.amp_type = None - @property - def num_steps(self): - return len(self.dataloader) - - def initialize(self, - dataloader, - model, - criterion, - optimizer, - lr_scheduler=None): - super().initialize(dataloader, - model, - criterion, - optimizer, - lr_scheduler=lr_scheduler) - if isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2, - ZeroRedundancyOptimizer_Level_3)): + def initialize(self, model: nn.Module, optimizer: Optimizer): + if isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2, + ZeroRedundancyOptimizer_Level_3)): self.use_zero_level_2_3 = True - assert self.amp_type != AMP_TYPE.PARALLEL, 'ZeRO Level 2 and 3 are mutually exclusive with AMP_TYPE.PARALLEL' + assert self.amp_type != AMP_TYPE.PARALLEL, \ + 'ZeRO Level 2 and 3 are mutually exclusive with AMP_TYPE.PARALLEL' if self.fp16: if self.amp_type == AMP_TYPE.TORCH: - self._torch_amp_scaler = torch_amp.GradScaler(**self.amp_cfg) + self._torch_amp_scaler = GradScaler(**self.amp_cfg) elif self.amp_type == AMP_TYPE.APEX: - self.model, self.optimizer = apex_amp.initialize( - self.model, self.optimizer, **self.amp_cfg) + model, optimizer = apex_amp.initialize(model, optimizer, **self.amp_cfg) - def forward_backward_step(self, forward_only=False, return_loss=True): + return model, optimizer + + def forward_backward_step(self, + data_iter: Iterable, + model: nn.Module, + criterion: nn.modules.loss._Loss, + optimizer: Optimizer = None, + forward_only: bool = False, + grad_accum_size: int = 1, + return_loss: bool = True): """The process function that loads loads a batch of dataset and feeds it to the model. The returned labels and loss will None if :attr:`return_loss` is False. + :param data_iter: Data iterator of the dataloader, e.g. iter(dataloader) + :param model: Model for training and inference + :param criterion: Loss function for training + :param optimizer: Optimizer used for training + :param forward_only: If True, the model is run for the forward pass, else back propagation will be executed + :param grad_accum_size: The number of iterations for gradient accumulation + :param return_loss: Loss will be returned if True + :type data_iter: Iterator + :type model: torch.nn.Module + :type criterion: torch.nn.modules.loss._Loss + :type optimizer: torch.optim.Optimizer + :type forward_only: bool, optional + :type grad_accum_size: int + :type return_loss: bool, optional :return: (output, label, loss) """ assert forward_only or return_loss, \ 'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.' - data, label = self.load_batch() + data, label = self.load_batch(data_iter) loss = None - # LSG: leave for debug, make sure dataloader is deterministic - # if forward_only: - # img = data[0] - # rank = gpc.get_local_rank(ParallelMode.DATA) - # world_size = gpc.get_world_size(ParallelMode.DATA) - # group = gpc.get_group(ParallelMode.DATA) - # input_list = [img.clone() for _ in range(world_size)] - # output_list = [torch.empty_like(img) for _ in range(world_size)] - # output_list[rank] = img.clone() - # dist.all_to_all(output_tensor_list=output_list, input_tensor_list=input_list, group=group) - # assert torch.equal(output_list[0], output_list[1]) # and torch.equal(output_list[1], output_list[2]) - # forward if self.fp16 and self.amp_type == AMP_TYPE.TORCH: with torch_amp.autocast(): - output = self.model(*data) + output = model(*data) if not isinstance(output, (tuple, list)): output = (output,) if return_loss: - loss = self.criterion(*output, *label) + loss = criterion(*output, *label) else: if self.use_zero_level_2_3 or self.amp_type == AMP_TYPE.PARALLEL: data = convert_to_fp16(data) - output = self.model(*data) + output = model(*data) + + if self.use_zero_level_2_3 or self.amp_type == AMP_TYPE.PARALLEL: + output = convert_to_fp32(output) + if not isinstance(output, (tuple, list)): output = (output,) if return_loss: - loss = self.criterion(*output, *label) + loss = criterion(*output, *label) + + loss /= grad_accum_size if not forward_only: # backward if self.use_zero_level_2_3: - self.optimizer.backward(loss) + optimizer.backward(loss) elif self.fp16: if self.amp_type == AMP_TYPE.APEX: - with apex_amp.scale_loss(loss, - self.optimizer) as scaled_loss: + with apex_amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward() elif self.amp_type == AMP_TYPE.TORCH: self._torch_amp_scaler.scale(loss).backward() elif self.amp_type == AMP_TYPE.PARALLEL: - loss = self.optimizer.scale_loss(loss) + loss = optimizer.scale_loss(loss) loss.backward() # scale back to display the original value in logs - loss.div_(self.optimizer.grad_scaler.scale) + loss.div_(optimizer.grad_scaler.scale) else: loss.backward() if return_loss: - return output, label, loss + return output, label, loss * grad_accum_size else: return output, None, None - def step(self): + def optimizer_step(self, model: nn.Module, optimizer: Optimizer, grad_clipping: float = 0.0): # step optimizer if self.fp16 and self.amp_type == AMP_TYPE.TORCH: - self._torch_amp_scaler.step(self.optimizer) + if grad_clipping > 0.0: + self._torch_amp_scaler.unscale_(optimizer) + clip_grad_norm_fp32(model.parameters(), grad_clipping) + self._torch_amp_scaler.step(optimizer) self._torch_amp_scaler.update() else: - self.optimizer.step() - - # update lr scheduler - if self.lr_scheduler is not None: - self.lr_scheduler.step() + if not self.fp16 and not self.use_zero_level_2_3 and grad_clipping > 0.0: + clip_grad_norm_fp32(model.parameters(), grad_clipping) + optimizer.step() diff --git a/colossalai/engine/schedule/_pipeline.py b/colossalai/engine/schedule/_pipeline.py index 0b477c0d5..6defea93d 100644 --- a/colossalai/engine/schedule/_pipeline.py +++ b/colossalai/engine/schedule/_pipeline.py @@ -15,7 +15,7 @@ from colossalai.nn import (ZeroRedundancyOptimizer_Level_2, from colossalai.utils import get_current_device from ._base_schedule import BaseSchedule from ._utils import convert_to_fp16 -from ..amp_type import AMP_TYPE +from ..amp import AMP_TYPE def squeeze(x: Union[Tensor, tuple, list]): @@ -93,12 +93,11 @@ class PipelineSchedule(BaseSchedule): ) # Pipeline schedule just puts data in memory - def load_batch(self): - self.check_initialized() - if self.data_iter is None: + def load_batch(self, data_iter): + if data_iter is None: raise RuntimeError('Dataloader is not defined.') self.batch_pos = 0 - data, label = next(self.data_iter) + data, label = next(data_iter) self.batch_data, self.batch_label = \ self._move_to_device(data), self._move_to_device(label) batch_size = self.batch_data.shape[0] @@ -117,23 +116,8 @@ class PipelineSchedule(BaseSchedule): self.batch_pos += self.microbatch_size return (data,), (label,) - @property - def num_steps(self): - return len(self.dataloader) - - def initialize(self, - dataloader, - model, - criterion, - optimizer, - lr_scheduler=None): - super().initialize(dataloader, - model, - criterion, - optimizer, - lr_scheduler=lr_scheduler) - if isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2, - ZeroRedundancyOptimizer_Level_3)): + def initialize(self, model, optimizer): + if isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)): raise TypeError( "Pipeline schedule is currently not compatible with ZeRO Level 2 and Level 3" ) @@ -145,7 +129,8 @@ class PipelineSchedule(BaseSchedule): 'default tensor dtype is set to torch.half for fp16 training', ranks=[0]) - def forward_step(self, input_tensor, return_tensors, return_loss=True): + def forward_step(self, model, criterion, input_tensor, return_tensors, + grad_accum_size, return_loss=True): """Forward step for passed-in model. If it is the first stage, the input tensor is obtained from data_iterator, otherwise the passed-in input_tensor is used. Returns output tensor. This is a helper function and can be ignored by users. @@ -156,14 +141,14 @@ class PipelineSchedule(BaseSchedule): if self.amp_type == AMP_TYPE.PARALLEL: input_tensor = convert_to_fp16(input_tensor) input_tensor = squeeze(input_tensor) - output_tensor = self.model(input_tensor) + output_tensor = model(input_tensor) output_tensor = squeeze(output_tensor) if gpc.is_last_rank(ParallelMode.PIPELINE): if return_loss: input_tensor, label = self.load_micro_batch() - loss_reduced = self.criterion(output_tensor, * - label) / self.num_microbatches + loss_reduced = criterion(output_tensor, *label) \ + / (self.num_microbatches * grad_accum_size) return_tensors.append( tuple((output_tensor, label[0], loss_reduced))) return loss_reduced @@ -174,7 +159,7 @@ class PipelineSchedule(BaseSchedule): else: return output_tensor - def backward_step(self, input_tensor, output_tensor, output_tensor_grad): + def backward_step(self, optimizer, input_tensor, output_tensor, output_tensor_grad): """Backward step through the passed-in output tensor. If it is the last stage, the output_tensor_grad is None, otherwise it is the gradients with respect to stage's output tensor. Returns the gradients with respect to the input tensor (None if first stage). @@ -187,7 +172,7 @@ class PipelineSchedule(BaseSchedule): # Backward pass. if output_tensor_grad is None and self.amp_type == AMP_TYPE.PARALLEL: - output_tensor = self.optimizer.scale_loss(output_tensor) + output_tensor = optimizer.scale_loss(output_tensor) torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad) # Collect the grad of the input_tensor. @@ -197,17 +182,24 @@ class PipelineSchedule(BaseSchedule): return input_tensor_grad - def forward_backward_step(self, forward_only=True, return_loss=True): + def forward_backward_step(self, + data_iter, + model, + criterion, + optimizer=None, + forward_only=False, + grad_accum_size: int = 1, + return_loss=True): """Runs non-interleaved 1F1B schedule, with communication between pipeline stages. Returns a tuple with losses if the last stage, an empty tuple otherwise. - + :return: (output, label, loss) """ assert forward_only or return_loss, \ 'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.' - self.load_batch() + self.load_batch(data_iter) num_warmup_microbatches = \ (gpc.get_world_size(ParallelMode.PIPELINE) - gpc.get_local_rank(ParallelMode.PIPELINE) - 1) @@ -233,9 +225,11 @@ class PipelineSchedule(BaseSchedule): if not gpc.is_first_rank(ParallelMode.PIPELINE): ft_shape = recv_tensor_meta(ft_shape) input_tensor = recv_forward(ft_shape) - output_tensor = self.forward_step(input_tensor, - return_tensors, - return_loss=return_loss) + output_tensor = self.forward_step( + model, criterion, + input_tensor, return_tensors, + grad_accum_size, return_loss=return_loss + ) if not gpc.is_last_rank(ParallelMode.PIPELINE): bt_shape = output_tensor.shape fs_checker = send_tensor_meta(output_tensor, fs_checker) @@ -257,9 +251,11 @@ class PipelineSchedule(BaseSchedule): for i in range(num_microbatches_remaining): last_iteration = (i == (num_microbatches_remaining - 1)) - output_tensor = self.forward_step(input_tensor, - return_tensors, - return_loss=return_loss) + output_tensor = self.forward_step( + model, criterion, + input_tensor, return_tensors, + grad_accum_size, return_loss=return_loss + ) if forward_only: send_forward(output_tensor) @@ -279,9 +275,11 @@ class PipelineSchedule(BaseSchedule): input_tensor = input_tensors.pop(0) output_tensor = output_tensors.pop(0) - input_tensor_grad = self.backward_step(input_tensor, - output_tensor, - output_tensor_grad) + input_tensor_grad = self.backward_step( + optimizer, + input_tensor, output_tensor, + output_tensor_grad + ) if last_iteration: input_tensor = None @@ -298,9 +296,11 @@ class PipelineSchedule(BaseSchedule): output_tensor_grad = recv_backward(bt_shape) - input_tensor_grad = self.backward_step(input_tensor, - output_tensor, - output_tensor_grad) + input_tensor_grad = self.backward_step( + optimizer, + input_tensor, output_tensor, + output_tensor_grad + ) send_backward(input_tensor_grad) @@ -309,8 +309,11 @@ class PipelineSchedule(BaseSchedule): output, label, loss = tuple(map(list, zip(*return_tensors))) return (torch.cat(output, dim=0), torch.cat(label, dim=0), - sum(loss)) + sum(loss) * grad_accum_size) else: return tuple((torch.cat(return_tensors, dim=0), None, None)) else: return tuple((None, None, None)) + + def optimizer_step(self, model, optimizer, grad_clipping: float = 0.0): + optimizer.step() diff --git a/colossalai/engine/schedule/_utils.py b/colossalai/engine/schedule/_utils.py index 9c4a2a19b..cdfd0246c 100644 --- a/colossalai/engine/schedule/_utils.py +++ b/colossalai/engine/schedule/_utils.py @@ -14,3 +14,14 @@ def convert_to_fp16(data: Union[Tensor, List[Tensor]]): else: raise TypeError(f"Expected argument 'data' to be a Tensor or a list/tuple of Tensor, but got {type(data)}") return ret + + +def convert_to_fp32(data: Union[Tensor, List[Tensor]]): + if isinstance(data, Tensor): + ret = data.float() + elif isinstance(data, (list, tuple)): + ret = [val.float() for val in data] + else: + raise TypeError(f"Expected argument 'data' to be a Tensor or a list/tuple of Tensor, but got {type(data)}") + return ret + diff --git a/colossalai/initialize.py b/colossalai/initialize.py index 35e8095b6..6806d86eb 100644 --- a/colossalai/initialize.py +++ b/colossalai/initialize.py @@ -6,18 +6,20 @@ import pprint import random from pathlib import Path from typing import Callable, Iterable, Optional, Union +from typing import Tuple import numpy as np import torch from torch.utils.data import DataLoader from colossalai.engine import AMP_TYPE, NoPipelineSchedule, PipelineSchedule +from colossalai.engine import Engine from colossalai.logging import get_global_dist_logger, init_global_dist_logger from colossalai.nn import DataParallelSampler from colossalai.nn.model.base_model import BaseModel from .builder import (ModelInitializer, build_dataset, build_loss, - build_lr_scheduler, build_model, build_optimizer, - build_optimizer_wrapper) + build_model, build_optimizer, + build_optimizer_wrapper, build_schedule) from .context import Config, ParallelMode from .core import global_context as gpc from .utils import get_current_device, sync_model_param_in_dp @@ -182,7 +184,7 @@ def initialize(config: Union[str, dict] = None, backend: str = None, train_dataloader: Optional[Union[Iterable, Callable]] = None, test_dataloader: Optional[Union[Iterable, Callable]] = None, - ): + ) -> Tuple[Engine, DataLoader, DataLoader]: '''Core function that initializes distributed environment, logger, cudnn, data, model, loss function, optimizer, and lr_scheduler(their configs are in gpc.config). :param config: config file or config file path are both acceptable @@ -201,7 +203,7 @@ def initialize(config: Union[str, dict] = None, :type train_dataloader: Optional[Union[Iterable, Callable]], optional :param test_dataloader: If None, the config is used to build a dataloder; Else, it should be a dataloader object or a function with no arguments which can build a dataloader, defaults to None :type test_dataloader: Optional[Union[Iterable, Callable]], optional - :return: (model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler) + :return: (engine, train_dataloader, test_dataloader, criterion) :rtype: tuple ''' # initialize distributed environment @@ -337,21 +339,7 @@ def initialize(config: Union[str, dict] = None, optimizer = build_optimizer_wrapper(fp16_cfg, optimizer) logger.info('Optimizer is created', ranks=[0]) - lr_scheduler = None - if hasattr(gpc.config, 'lr_scheduler'): - if hasattr(gpc.config, 'num_steps'): - total_steps = gpc.config.num_steps - elif hasattr(gpc.config, 'num_epochs'): - total_steps = int(gpc.config.num_epochs * len(train_dataloader)) - else: - raise Exception( - 'Please specify training stopping criterion num_steps or num_epochs in your configuration.' - ) - lr_scheduler = build_lr_scheduler(gpc.config.lr_scheduler, optimizer, - total_steps, len(train_dataloader)) - logger.info('Learning rate scheduler is created', ranks=[0]) - - # pipeline or no pipeline schedule + # build schedule and engine if hasattr(gpc.config, 'fp16'): amp_type = gpc.config.fp16.mode amp_cfg = gpc.config.fp16.copy() @@ -360,12 +348,32 @@ def initialize(config: Union[str, dict] = None, amp_type = None amp_cfg = None - if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1: - assert hasattr(gpc.config, - 'schedule'), "Config 'schedule' not found in your configuration file for pipeline parallel training" + engine_cfg = gpc.config.get('engine', dict()) + schedule_cfg = engine_cfg.pop('schedule', None) + + schedule_type = None + if schedule_cfg is not None: + schedule_type = schedule_cfg.get('type', None) + + if schedule_type is not None: + # run customized schedule + schedule_cfg['amp_type'] = amp_type + schedule_cfg['amp_config'] = amp_cfg + schedule = build_schedule(schedule_cfg) + elif gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1: + assert schedule_cfg is not None, \ + "Config 'engine.schedule' not found in your configuration file for pipeline parallel training" schedule = PipelineSchedule( - amp_type=amp_type, amp_config=amp_cfg, **gpc.config.schedule.copy()) + amp_type=amp_type, amp_config=amp_cfg, **schedule_cfg.copy()) else: schedule = NoPipelineSchedule(amp_type=amp_type, amp_config=amp_cfg) - return model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler + engine = Engine( + model=model, + optimizer=optimizer, + criterion=criterion, + step_schedule=schedule, + **gpc.config.get('engine', dict()) + ) + + return engine, train_dataloader, test_dataloader diff --git a/colossalai/nn/layer/parallel_2d/_operation.py b/colossalai/nn/layer/parallel_2d/_operation.py index 2c7eb8ac6..d9ecf2fad 100644 --- a/colossalai/nn/layer/parallel_2d/_operation.py +++ b/colossalai/nn/layer/parallel_2d/_operation.py @@ -7,6 +7,7 @@ from torch import Tensor from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc from colossalai.utils import get_current_device +from torch.cuda.amp import custom_bwd, custom_fwd def matmul_2d(a, @@ -60,6 +61,7 @@ class Matmul_AB_2D(torch.autograd.Function): """Matrix multiplication for :math:`C = AB` """ @staticmethod + @custom_fwd(cast_inputs=torch.float16) def forward(ctx: Any, A: Tensor, B: Tensor, @@ -120,32 +122,32 @@ class Matmul_AB_2D(torch.autograd.Function): return out @staticmethod + @custom_bwd def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: A, B = ctx.saved_tensors - A_grad = Matmul_ABT_2D.forward( - None, - output_grad, B, - ctx.summa_dim, ctx.A_shape, - ctx.row_rank, ctx.col_rank, - ctx.row_parallel_mode, - ctx.col_parallel_mode, - ctx.data_parallel_rank, - ctx.pipeline_parallel_rank, - ctx.pipeline_parallel_size, - ctx.tensor_parallel_size - ) - B_grad = Matmul_ATB_2D.forward( - None, - A, output_grad, - ctx.summa_dim, ctx.B_shape, - ctx.row_rank, ctx.col_rank, - ctx.row_parallel_mode, - ctx.col_parallel_mode, - ctx.data_parallel_rank, - ctx.pipeline_parallel_rank, - ctx.pipeline_parallel_size, - ctx.tensor_parallel_size - ) + with torch.no_grad(): + A_grad = Matmul_ABT_2D.apply( + output_grad, B, + ctx.summa_dim, ctx.A_shape, + ctx.row_rank, ctx.col_rank, + ctx.row_parallel_mode, + ctx.col_parallel_mode, + ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, + ctx.tensor_parallel_size + ) + B_grad = Matmul_ATB_2D.apply( + A, output_grad, + ctx.summa_dim, ctx.B_shape, + ctx.row_rank, ctx.col_rank, + ctx.row_parallel_mode, + ctx.col_parallel_mode, + ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, + ctx.tensor_parallel_size + ) return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None @@ -153,6 +155,7 @@ class Matmul_ABT_2D(torch.autograd.Function): """Matrix multiplication for :math:`C = AB^T` """ @staticmethod + @custom_fwd(cast_inputs=torch.float16) def forward(ctx: Any, A: Tensor, B: Tensor, @@ -214,32 +217,33 @@ class Matmul_ABT_2D(torch.autograd.Function): return out @staticmethod + @custom_bwd def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: A, B = ctx.saved_tensors - A_grad = Matmul_AB_2D.forward( - None, - output_grad, B, - ctx.summa_dim, ctx.A_shape, - ctx.row_rank, ctx.col_rank, - ctx.row_parallel_mode, - ctx.col_parallel_mode, - ctx.data_parallel_rank, - ctx.pipeline_parallel_rank, - ctx.pipeline_parallel_size, - ctx.tensor_parallel_size - ) - B_grad = Matmul_ATB_2D.forward( - None, - output_grad, A, - ctx.summa_dim, ctx.B_shape, - ctx.row_rank, ctx.col_rank, - ctx.row_parallel_mode, - ctx.col_parallel_mode, - ctx.data_parallel_rank, - ctx.pipeline_parallel_rank, - ctx.pipeline_parallel_size, - ctx.tensor_parallel_size - ) + + with torch.no_grad(): + A_grad = Matmul_AB_2D.apply( + output_grad, B, + ctx.summa_dim, ctx.A_shape, + ctx.row_rank, ctx.col_rank, + ctx.row_parallel_mode, + ctx.col_parallel_mode, + ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, + ctx.tensor_parallel_size + ) + B_grad = Matmul_ATB_2D.apply( + output_grad, A, + ctx.summa_dim, ctx.B_shape, + ctx.row_rank, ctx.col_rank, + ctx.row_parallel_mode, + ctx.col_parallel_mode, + ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, + ctx.tensor_parallel_size + ) return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None @@ -247,6 +251,7 @@ class Matmul_ATB_2D(torch.autograd.Function): """Matrix multiplication for :math:`C = A^TB` """ @staticmethod + @custom_fwd(cast_inputs=torch.float16) def forward(ctx: Any, A: Tensor, B: Tensor, @@ -308,32 +313,33 @@ class Matmul_ATB_2D(torch.autograd.Function): return out @staticmethod + @custom_bwd def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: A, B = ctx.saved_tensors - A_grad = Matmul_ABT_2D.forward( - None, - B, output_grad, - ctx.summa_dim, ctx.A_shape, - ctx.row_rank, ctx.col_rank, - ctx.row_parallel_mode, - ctx.col_parallel_mode, - ctx.data_parallel_rank, - ctx.pipeline_parallel_rank, - ctx.pipeline_parallel_size, - ctx.tensor_parallel_size - ) - B_grad = Matmul_AB_2D.forward( - None, - A, output_grad, - ctx.summa_dim, ctx.B_shape, - ctx.row_rank, ctx.col_rank, - ctx.row_parallel_mode, - ctx.col_parallel_mode, - ctx.data_parallel_rank, - ctx.pipeline_parallel_rank, - ctx.pipeline_parallel_size, - ctx.tensor_parallel_size - ) + + with torch.no_grad(): + A_grad = Matmul_ABT_2D.apply( + B, output_grad, + ctx.summa_dim, ctx.A_shape, + ctx.row_rank, ctx.col_rank, + ctx.row_parallel_mode, + ctx.col_parallel_mode, + ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, + ctx.tensor_parallel_size + ) + B_grad = Matmul_AB_2D.apply( + A, output_grad, + ctx.summa_dim, ctx.B_shape, + ctx.row_rank, ctx.col_rank, + ctx.row_parallel_mode, + ctx.col_parallel_mode, + ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, + ctx.tensor_parallel_size + ) return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None @@ -341,6 +347,7 @@ class Add_Bias_2D(torch.autograd.Function): """Matrix add bias: :math:`C = A + b` """ @staticmethod + @custom_fwd(cast_inputs=torch.float16) def forward(ctx: Any, input: Tensor, bias: Tensor, @@ -384,6 +391,7 @@ class Add_Bias_2D(torch.autograd.Function): return output @staticmethod + @custom_bwd def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: row_rank = ctx.row_rank col_rank = ctx.col_rank @@ -423,6 +431,7 @@ class Add_Bias_2D(torch.autograd.Function): class _LayerNorm_2D(torch.autograd.Function): @staticmethod + @custom_fwd(cast_inputs=torch.float32) def forward(ctx: Any, input: Tensor, E_x: Tensor, @@ -440,6 +449,7 @@ class _LayerNorm_2D(torch.autograd.Function): return output @staticmethod + @custom_bwd def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: row_parallel_mode = ctx.row_parallel_mode col_parallel_mode = ctx.col_parallel_mode @@ -492,6 +502,7 @@ class _LayerNorm_2D(torch.autograd.Function): class _ViT_Split_Input_2D(torch.autograd.Function): @staticmethod + @custom_fwd(cast_inputs=torch.float16) def forward(ctx: Any, inputs: Tensor, batch_size: int, @@ -509,6 +520,7 @@ class _ViT_Split_Input_2D(torch.autograd.Function): return output @staticmethod + @custom_bwd def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: # output_grad: [b/q, s, h/q] # grads: [b, s, h/q] diff --git a/colossalai/nn/lr_scheduler/__init__.py b/colossalai/nn/lr_scheduler/__init__.py index 82e28ff88..fd44686f0 100644 --- a/colossalai/nn/lr_scheduler/__init__.py +++ b/colossalai/nn/lr_scheduler/__init__.py @@ -1,5 +1,5 @@ from .cosine import CosineAnnealingLR, CosineAnnealingWarmupLR, FlatAnnealingLR, FlatAnnealingWarmupLR -from .linear import LinearWarmupLR, LinearWarmupDecay +from .linear import LinearWarmupLR from .multistep import MultiStepLR, MultiStepWarmupLR from .onecycle import OneCycleLR from .poly import PolynomialLR, PolynomialWarmupLR diff --git a/colossalai/nn/lr_scheduler/cosine.py b/colossalai/nn/lr_scheduler/cosine.py index 067636a3d..0df30baab 100644 --- a/colossalai/nn/lr_scheduler/cosine.py +++ b/colossalai/nn/lr_scheduler/cosine.py @@ -66,11 +66,10 @@ class CosineAnnealingWarmupLR(WarmupScheduler): :type last_epoch: int, optional """ - def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, eta_min: int = 0, last_epoch: int = -1, - **kwargs): + def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, eta_min: int = 0, last_epoch: int = -1): base_scheduler = _CosineAnnealingLR( - optimizer, total_steps - warmup_steps, eta_min=eta_min) - super().__init__(optimizer, warmup_steps, base_scheduler, last_epoch=last_epoch) + optimizer, total_steps - warmup_steps, eta_min=eta_min, last_epoch=last_epoch) + super().__init__(optimizer, warmup_steps, base_scheduler) @LR_SCHEDULERS.register_module diff --git a/colossalai/nn/lr_scheduler/delayed.py b/colossalai/nn/lr_scheduler/delayed.py index c8972c922..173d2f52c 100644 --- a/colossalai/nn/lr_scheduler/delayed.py +++ b/colossalai/nn/lr_scheduler/delayed.py @@ -55,7 +55,7 @@ class DelayerScheduler(_LRScheduler): class WarmupScheduler(_LRScheduler): - """ Starts with a linear warmup lr schedule until it reaches N epochs the applies a scheduler + """ Starts with a linear warmup lr schedule until it reaches N epochs the applies a scheduler :param optimizer: Wrapped optimizer. :type optimizer: torch.optim.Optimizer @@ -66,11 +66,8 @@ class WarmupScheduler(_LRScheduler): :param last_epoch: The index of last epoch, defaults to -1 :type last_epoch: int, optional """ - def __init__(self, optimizer, warmup_epochs, after_scheduler, last_epoch=-1): - if warmup_epochs < 0: - raise ValueError(f'warmup_epochs must >= 0, got {warmup_epochs}') - self.warmup_epochs = warmup_epochs + self.warmup_epochs = int(warmup_epochs) self.after_scheduler = after_scheduler self.finished = False super().__init__(optimizer, last_epoch) @@ -79,14 +76,10 @@ class WarmupScheduler(_LRScheduler): if self.last_epoch >= self.warmup_epochs: if not self.finished: self.after_scheduler.base_lrs = self.base_lrs - # reset lr to base_lr - for group, base_lr in zip(self.optimizer.param_groups, self.base_lrs): - group['lr'] = base_lr self.finished = True - with _enable_get_lr_call(self.after_scheduler): - return self.after_scheduler.get_lr() + return self.after_scheduler.get_lr() - return [(self.last_epoch + 1) / (self.warmup_epochs + 1) * lr for lr in self.base_lrs] + return [(self.last_epoch + 1) / self.warmup_epochs * lr for lr in self.base_lrs] def step(self, epoch=None): if self.finished: diff --git a/colossalai/nn/lr_scheduler/linear.py b/colossalai/nn/lr_scheduler/linear.py index afc68c5a7..b9498baf0 100644 --- a/colossalai/nn/lr_scheduler/linear.py +++ b/colossalai/nn/lr_scheduler/linear.py @@ -28,18 +28,3 @@ class LinearWarmupLR(_LRScheduler): else: return [(self.total_steps - self.last_epoch) / (self.total_steps - self.warmup_steps) * lr for lr in self.base_lrs] - - -@LR_SCHEDULERS.register_module -class LinearWarmupDecay(_LRScheduler): - def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, last_epoch: int = -1, **kwargs): - self.warmup_steps = int(warmup_steps) - self.total_steps = total_steps - super().__init__(optimizer, last_epoch=last_epoch) - - def get_lr(self): - if self.last_epoch < self.warmup_steps: - return [(self.last_epoch + 1) / self.warmup_steps * lr for lr in self.base_lrs] - else: - return [(self.total_steps - self.last_epoch - 1) / (self.total_steps - self.warmup_steps) * lr for lr in - self.base_lrs] diff --git a/colossalai/nn/lr_scheduler/multistep.py b/colossalai/nn/lr_scheduler/multistep.py index 46420765c..5def4a1fa 100644 --- a/colossalai/nn/lr_scheduler/multistep.py +++ b/colossalai/nn/lr_scheduler/multistep.py @@ -27,12 +27,7 @@ class MultiStepLR(_MultiStepLR): :type last_epoch: int, optional """ - def __init__(self, optimizer, total_steps: int, milestones: List[int] = None, gamma: float = 0.1, - num_steps_per_epoch: int = -1, last_epoch: int = -1, **kwargs): - if num_steps_per_epoch <= 0: - raise ValueError( - f'num_steps_per_epoch must > 0, got {num_steps_per_epoch}') - milestones = [v * num_steps_per_epoch for v in milestones] + def __init__(self, optimizer, total_steps: int, milestones: List[int] = None, gamma: float = 0.1, last_epoch: int = -1, **kwargs): super().__init__(optimizer, milestones, gamma=gamma, last_epoch=last_epoch) @@ -57,14 +52,11 @@ class MultiStepWarmupLR(WarmupScheduler): """ def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, milestones: List[int] = None, - gamma: float = 0.1, num_steps_per_epoch: int = -1, last_epoch: int = -1, **kwargs): + gamma: float = 0.1, last_epoch: int = -1, **kwargs): if len(milestones) == 0: raise ValueError('milestones cannot be empty') - if num_steps_per_epoch <= 0: - raise ValueError( - f'num_steps_per_epoch must > 0, got {num_steps_per_epoch}') - milestones = [v * num_steps_per_epoch - warmup_steps for v in milestones if v * - num_steps_per_epoch >= warmup_steps] + milestones = [ + v - warmup_steps for v in milestones if v >= warmup_steps] base_scheduler = _MultiStepLR(optimizer, milestones=milestones, gamma=gamma) super().__init__(optimizer, warmup_steps, base_scheduler, last_epoch=last_epoch) diff --git a/colossalai/nn/lr_scheduler/torch.py b/colossalai/nn/lr_scheduler/torch.py index 3ac0121ff..e739084b6 100644 --- a/colossalai/nn/lr_scheduler/torch.py +++ b/colossalai/nn/lr_scheduler/torch.py @@ -1,7 +1,7 @@ from torch.optim.lr_scheduler import LambdaLR as _LambdaLR from torch.optim.lr_scheduler import MultiplicativeLR as _MultiplicativeLR from torch.optim.lr_scheduler import StepLR as _StepLR -from torch.optim.lr_scheduler import _LRScheduler +from torch.optim.lr_scheduler import ExponentialLR as _ExponentialLR from colossalai.registry import LR_SCHEDULERS @@ -25,11 +25,8 @@ class LambdaLR(_LambdaLR): :type last_epoch: int, optional """ - def __init__(self, optimizer, total_steps, lr_lambda=None, num_steps_per_epoch: int = -1, - last_epoch: int = -1) -> None: - def func(step): return lr_lambda(step // num_steps_per_epoch) - - super().__init__(optimizer, func, last_epoch=last_epoch) + def __init__(self, optimizer, total_steps, lr_lambda=None, last_epoch: int = -1) -> None: + super().__init__(optimizer, lr_lambda, last_epoch=last_epoch) @LR_SCHEDULERS.register_module @@ -51,11 +48,8 @@ class MultiplicativeLR(_MultiplicativeLR): :type last_epoch: int, optional """ - def __init__(self, optimizer, total_steps, lr_lambda=None, num_steps_per_epoch: int = -1, - last_epoch: int = -1) -> None: - def func(step): return lr_lambda(step // num_steps_per_epoch) - - super().__init__(optimizer, func, last_epoch=last_epoch) + def __init__(self, optimizer, total_steps, lr_lambda=None, last_epoch: int = -1) -> None: + super().__init__(optimizer, lr_lambda, last_epoch=last_epoch) @LR_SCHEDULERS.register_module @@ -79,14 +73,13 @@ class StepLR(_StepLR): :type last_epoch: int, optional """ - def __init__(self, optimizer, total_steps, step_size: int = 1, gamma: float = 0.1, num_steps_per_epoch: int = -1, - last_epoch: int = -1) -> None: - super().__init__(optimizer, step_size * num_steps_per_epoch, + def __init__(self, optimizer, total_steps, step_size: int = 1, gamma: float = 0.1, last_epoch: int = -1) -> None: + super().__init__(optimizer, step_size, gamma=gamma, last_epoch=last_epoch) @LR_SCHEDULERS.register_module -class ExponentialLR(_LRScheduler): +class ExponentialLR(_ExponentialLR): """Decays the learning rate of each parameter group by gamma every epoch. When last_epoch=-1, sets initial lr as lr @@ -102,21 +95,6 @@ class ExponentialLR(_LRScheduler): :type last_epoch: int, optional """ - def __init__(self, optimizer, total_steps, gamma: float = 1.0, num_steps_per_epoch: int = -1, + def __init__(self, optimizer, total_steps, gamma: float = 1.0, last_epoch: int = -1) -> None: - self.gamma = gamma - self.num_steps_per_epoch = num_steps_per_epoch - super().__init__(optimizer, last_epoch=last_epoch) - - def get_lr(self): - if self.last_epoch == 0: - return self.base_lrs - elif (self.last_epoch + 1) % self.num_steps_per_epoch == 0: - return [group['lr'] * self.gamma - for group in self.optimizer.param_groups] - return [group['lr'] - for group in self.optimizer.param_groups] - - def _get_closed_form_lr(self): - return [base_lr * self.gamma ** (self.last_epoch // self.num_steps_per_epoch) - for base_lr in self.base_lrs] + super().__init__(optimizer, gamma, last_epoch=last_epoch) diff --git a/colossalai/nn/optimizer/_utils.py b/colossalai/nn/optimizer/_utils.py index 1be8ffc1b..6cd92bb38 100644 --- a/colossalai/nn/optimizer/_utils.py +++ b/colossalai/nn/optimizer/_utils.py @@ -106,7 +106,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): tensor_parallel_norm = _calc_lp(tensor_parallel_grads, norm_type) no_tensor_parallel_grads = _calc_lp( no_tensor_parallel_grads, norm_type) - if gpc.is_initialized(ParallelMode.TENSOR): + if gpc.is_initialized(ParallelMode.TENSOR) and len(tensor_parallel_grads) > 0: # Sum across all model-parallel GPUs. torch.distributed.all_reduce(tensor_parallel_norm, op=torch.distributed.ReduceOp.SUM, diff --git a/colossalai/nn/optimizer/zero_redundancy_optimizer_level_2.py b/colossalai/nn/optimizer/zero_redundancy_optimizer_level_2.py index 17e277843..1a57c5876 100644 --- a/colossalai/nn/optimizer/zero_redundancy_optimizer_level_2.py +++ b/colossalai/nn/optimizer/zero_redundancy_optimizer_level_2.py @@ -6,6 +6,7 @@ import math import torch import torch.distributed as dist + try: from deepspeed.git_version_info import version from deepspeed.moe.utils import is_moe_param @@ -13,7 +14,7 @@ try: from deepspeed.ops.op_builder import UtilsBuilder from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION_GRADIENTS except ImportError: - print('DeepSpeed is required if you want to use ZeRO.') + pass from packaging import version as pkg_version from torch._six import inf from torch.distributed.distributed_c10d import _get_global_rank @@ -251,7 +252,7 @@ class ZeroRedundancyOptimizer_Level_2(Optimizer): self.nccl_start_alignment_factor = 2 assert ( - allgather_bucket_size % self.nccl_start_alignment_factor == 0), f"allgather_bucket_size must be a multiple of nccl_start_alignment_factor, {self.nccl_start_alignment_factor} " + allgather_bucket_size % self.nccl_start_alignment_factor == 0), f"allgather_bucket_size must be a multiple of nccl_start_alignment_factor, {self.nccl_start_alignment_factor} " self.all_reduce_print = False self.dtype = self.optimizer.param_groups[0]['params'][0].dtype @@ -759,7 +760,7 @@ class ZeroRedundancyOptimizer_Level_2(Optimizer): elif start_index > current_index and start_index < (current_index + param_size): assert ( - first_offset == 0), "This can happen either zero or only once as this must be the first tensor in the partition" + first_offset == 0), "This can happen either zero or only once as this must be the first tensor in the partition" first_offset = start_index - current_index set_key_value_list(self.param_to_partition_ids[i], @@ -803,7 +804,7 @@ class ZeroRedundancyOptimizer_Level_2(Optimizer): def report_ipg_memory_usage(self, tag, param_elems): elem_count = self.elements_in_ipg_bucket + param_elems percent_of_bucket_size = ( - 100.0 * elem_count) // self.reduce_bucket_size + 100.0 * elem_count) // self.reduce_bucket_size if self.verbose: report_memory_usage( f"{tag}: elems in_bucket {self.elements_in_ipg_bucket} param {param_elems} max_percent {percent_of_bucket_size}" @@ -1491,7 +1492,7 @@ class ZeroRedundancyOptimizer_Level_2(Optimizer): params_in_partition.append(tensor) assert ( - first_offset == 0), "This can happen either zero or only once as this must be the first tensor in the partition" + first_offset == 0), "This can happen either zero or only once as this must be the first tensor in the partition" first_offset = start_index - current_index else: @@ -1799,7 +1800,7 @@ class ZeroRedundancyOptimizer_Level_2(Optimizer): num_elements = shard_size assert shard_size * \ - num_shards <= partitioned_params[partition_id].numel() + num_shards <= partitioned_params[partition_id].numel() for shard_id in range(num_shards): @@ -2248,7 +2249,7 @@ def estimate_zero2_model_states_mem_needs(total_params, if cpu_offload: gpu_mem = 2 * total_params cpu_mem = total_params * \ - max(4 * total_gpus, 16) * additional_buffer_factor + max(4 * total_gpus, 16) * additional_buffer_factor else: gpu_mem = 4 * total_params + int(16 * total_params / total_gpus) cpu_mem = total_params * 4 * num_gpus_per_node * additional_buffer_factor diff --git a/colossalai/nn/optimizer/zero_redundancy_optimizer_level_3.py b/colossalai/nn/optimizer/zero_redundancy_optimizer_level_3.py index 6f5d7969c..4e54f3cd3 100644 --- a/colossalai/nn/optimizer/zero_redundancy_optimizer_level_3.py +++ b/colossalai/nn/optimizer/zero_redundancy_optimizer_level_3.py @@ -21,7 +21,7 @@ try: from deepspeed.runtime.zero.partition_parameters import * from deepspeed.runtime.zero.partition_parameters import _init_external_params except ImportError: - print('DeepSpeed is required if you want to use ZeRO.') + pass from torch._six import inf from torch.distributed.distributed_c10d import _get_global_rank diff --git a/colossalai/registry/__init__.py b/colossalai/registry/__init__.py index 99aedc495..1de1c98ae 100644 --- a/colossalai/registry/__init__.py +++ b/colossalai/registry/__init__.py @@ -20,3 +20,4 @@ TRANSFORMS = Registry('transforms', third_party_library=[transforms]) PIPE_ALLOC_POLICY = Registry('pipeline_allocation_policy') SAMPLERS = Registry('samplers') LR_SCHEDULERS = Registry('lr_schedulers') +SCHEDULE = Registry('schedules') diff --git a/colossalai/trainer/__init__.py b/colossalai/trainer/__init__.py index 34e38d54a..57f7b7495 100644 --- a/colossalai/trainer/__init__.py +++ b/colossalai/trainer/__init__.py @@ -1,5 +1,5 @@ from ._trainer import Trainer from .hooks import * -from .metric import Loss, Accuracy2D, Accuracy3D, Accuracy2p5D +from .metric import Loss, Accuracy2D, Accuracy3D, Accuracy2p5D, LearningRate -__all__ = ['Trainer', 'Loss', 'Accuracy3D', 'Accuracy2D', 'Accuracy2p5D'] +__all__ = ['Trainer', 'Loss', 'Accuracy3D', 'Accuracy2D', 'Accuracy2p5D', 'LearningRate'] diff --git a/colossalai/trainer/_trainer.py b/colossalai/trainer/_trainer.py index 673349640..96a82d995 100644 --- a/colossalai/trainer/_trainer.py +++ b/colossalai/trainer/_trainer.py @@ -1,7 +1,6 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from typing import Optional from typing import Union, List import torch @@ -10,12 +9,11 @@ from torch.utils.data import DataLoader from tqdm import tqdm from colossalai.builder import build_hooks -from colossalai.checkpointing import save_checkpoint, load_checkpoint, get_checkpoint_path -from colossalai.context import Config from colossalai.engine import Engine from colossalai.logging import get_global_dist_logger -from colossalai.utils import get_global_multitimer, is_dp_rank_0, is_tp_rank_0, is_no_pp_or_last_stage from colossalai.nn.data import DataParallelSampler +from colossalai.utils import MultiTimer +from colossalai.utils import is_dp_rank_0, is_tp_rank_0, is_no_pp_or_last_stage class Trainer: @@ -30,43 +28,31 @@ class Trainer: :type hoooks_cfg: Config, optional :type verbose: bool, optional """ + def __init__(self, engine: Engine, - hooks_cfg: Optional[Config] = None, - verbose: bool = False): + verbose: bool = False, + timer: MultiTimer = None): # training-ralated params self._engine = engine - self._max_epochs = float('inf') - self._max_steps = float('inf') + self._max_epochs = 0 self._cur_epoch = 0 + self._max_steps = 0 self._cur_step = 0 - - # data-related params - self._train_dataloader = None - self._test_dataloader = None + self._steps_per_epoch = 0 # misc params - self._display_progress = False self._logger = get_global_dist_logger() self._verbose = verbose # hooks can store states in this dict, and could be consumed by other hooks - self.states = {} + self.states = dict() # build hooks self.hooks = list() - if hooks_cfg is not None: - for cfg in hooks_cfg: - hook = build_hooks(cfg, self) - self.hooks.append(hook) - self.hooks.sort(key=lambda hook: hook.priority) - if self._verbose: - for hook in self.hooks: - self._logger.info( - f'build {hook.__class__.__name__} for train, priority = {hook.priority}', ranks=[0]) - # timer - self._timer = get_global_multitimer() + # multi-timer for time benchmarking + self._timer = timer @property def cur_epoch(self): @@ -74,13 +60,65 @@ class Trainer: """ return self._cur_epoch + @cur_epoch.setter + def cur_epoch(self, epoch: int): + """Set how many epochs have been processed. + """ + # allow setter for training resumption + self._cur_epoch = epoch + @property def cur_step(self): """Returns how many iteration steps have been processed. """ return self._cur_step - def call_hooks(self, func, output=None): + @property + def max_epochs(self): + return self._max_epochs + + @property + def max_steps(self): + return self._max_steps + + @property + def steps_per_epoch(self): + return self._steps_per_epoch + + @property + def engine(self): + return self._engine + + @engine.setter + def engine(self, engine_: Engine): + self._engine = engine_ + + def _set_current_step(self, epoch: int): + """Sets current step number. + + :param epoch: Step number to be set + :type epoch: int + """ + self._cur_step = epoch * self._steps_per_epoch + + def _call_timer(self, action: str, item: str, *args, **kwargs) -> None: + """Call timer funciton with a given timer name. + + :param action: Function to be called on timer + :type action: str + :param item: Name of the timer + :type item: str + """ + + if self._timer is not None: + getattr(self._timer, action)(item, *args, **kwargs) + + def _reset_states(self) -> None: + """Clear trainer states + """ + self.states = dict() + + def _call_hooks(self, func, output=None): """Calls specific hooks in the current time point. :param func: A string represents the time point @@ -95,161 +133,186 @@ class Trainer: else: getattr(hook, func)(*output) - def exceed_max_step(self): - """Checks whether the trainer exceeds the maximum number of runnning iterations. + @staticmethod + def _should_display_progress(display_progress: bool): + """ Only display progress on DP rank 0, TP rank 0 and PP last rank """ - return self._cur_step >= self._max_steps + return display_progress and is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage() - def set_epoch(self, epoch): - """Sets current epoch number. - - :param epoch: Epoch number to be set - :type epoch: int - """ - self._cur_epoch = epoch - - def _recover_steps(self): - step = self.cur_step * self._engine.schedule.num_steps - self._cur_step = step - - def _set_display_progress(self, display_progress: bool): - self._display_progress = display_progress and is_dp_rank_0( - ) and is_tp_rank_0() and is_no_pp_or_last_stage() - - def _train_epoch(self, epoch: int = None): + def _train_epoch(self, + train_dataloader: DataLoader, + epoch: int = None, + display_progress: bool = False): # set sampler epoch if epoch is not None and \ - hasattr(self._engine.train_dataloader, 'sampler') and \ - isinstance(self._engine.train_dataloader.sampler, DataParallelSampler): - self._engine.train_dataloader.sampler.set_epoch(epoch) + hasattr(train_dataloader, 'sampler') and \ + isinstance(train_dataloader.sampler, DataParallelSampler): + train_dataloader.sampler.set_epoch(epoch) + # set training state self._engine.train() - - progress = range(self._engine.schedule.num_steps) - if self._display_progress: + data_iter = iter(train_dataloader) + progress = range(self._steps_per_epoch) + if display_progress: if epoch is None: progress = tqdm(progress, desc='[Train]') else: progress = tqdm(progress, desc=f'[Epoch {epoch} train]') # train 1 epoch - self.call_hooks('before_train_epoch') - self._timer.start('train-epoch') - for _ in progress: + self._call_hooks('before_train_epoch') + self._call_timer(action='start', item='train-epoch') + for i in progress: + self._call_hooks('before_train_iter') + self._call_timer(action='start', item='train-step') + + if i == self._steps_per_epoch - 1: + is_last_iteration = True + else: + is_last_iteration = False + + # run 1 training step + logits, label, loss = self._engine.step(data_iter, is_last_iteration) + self._call_timer(action='stop', item='train-step', keep_in_history=True) + self._call_hooks('after_train_iter', output=(logits, label, loss)) + self._cur_step += 1 - self.call_hooks('before_train_iter') - self._timer.start('train-step') - logits, label, loss = self._engine.step() - self._timer.stop('train-step', keep_in_history=True) - self.call_hooks('after_train_iter', output=(logits, label, loss)) - - if self.exceed_max_step(): - # stop when max iter is reached + # stop when max iter is reached + if self._exceed_max_step(): break - self._timer.stop('train-epoch', keep_in_history=True) - self.call_hooks('after_train_epoch') - self._timer.reset('train-step') + + self._call_timer(action='stop', item='train-epoch', keep_in_history=True) + self._call_hooks('after_train_epoch') + self._call_timer(action='reset', item='train-step') def _eval(self, + test_dataloader: DataLoader, epoch: int = None, - return_loss: bool = True): + display_progress: bool = False): # switch engine status self._engine.eval() - self.call_hooks('before_test') + data_iter = iter(test_dataloader) + num_steps = len(test_dataloader) + + self._call_hooks('before_test') with torch.no_grad(): # prepare progress bar - progress = range(self._engine.schedule.num_steps) - if self._display_progress: + progress = range(num_steps) + if display_progress: desc = 'Evaluation' if epoch is not None: desc = '[Epoch %d val]' % epoch progress = tqdm(progress, desc=desc) - self.call_hooks('before_test_epoch') - self._timer.start('test-epoch') + self._call_hooks('before_test_epoch') + self._call_timer(action='start', item='test-epoch') for _ in progress: - self.call_hooks('before_test_iter') - self._timer.start('test-step') - logits, label, loss = self._engine.step( - return_loss=return_loss) - self._timer.stop('test-step', keep_in_history=True) - self.call_hooks('after_test_iter', - output=(logits, label, loss)) - self._timer.stop('test-epoch', keep_in_history=True) - self.call_hooks('after_test_epoch') - self.call_hooks('after_test') - self._timer.reset('test-step') - self._timer.reset('test-epoch') + self._call_hooks('before_test_iter') + self._call_timer(action='start', item='test-step') + logits, label, loss = self._engine.step(data_iter, return_loss=True) + self._call_timer(action='stop', item='test-step', keep_in_history=True) + self._call_hooks('after_test_iter', + output=(logits, label, loss)) + self._call_timer(action='stop', item='test-epoch', keep_in_history=True) + self._call_hooks('after_test_epoch') + self._call_hooks('after_test') + self._call_timer(action='reset', item='test-step') + self._call_timer(action='reset', item='test-epoch') + + def _exceed_max_step(self): + return self._max_steps is not None and self._cur_step > self._max_steps def fit(self, train_dataloader: DataLoader, - test_dataloader: DataLoader = None, - max_epochs: int = None, + epochs: int, max_steps: int = None, + test_dataloader: DataLoader = None, test_interval: int = 1, - display_progress: bool = False): + hooks_cfg: dict = None, + display_progress: bool = False, + ): """Trains the model to fit training data. :param train_dataloader: DataLoader in training - :param test_dataloader: DataLoader in testing - :param max_epochs: Maximum number of epoches + :param epochs: Maximum number of epoches :param max_steps: Maximum number of running iterations + :param test_dataloader: DataLoader in testing :param test_interval: Interval of testing + :param hooks_cfg: A list of hook configuration :param display_progress: If True, the training progress will be printed :type train_dataloader: DataLoader - :type test_dataloader: DataLoader - :type max_epochs: int + :type epochs: int :type max_steps: int + :type test_dataloader: DataLoader :type test_interval: int + :type hooks_cfg: dict :type display_progress: bool + :type gradient_accumulation: int """ - # prepare dataloaders - self._train_dataloader = train_dataloader - self._engine.set_dataloader(self._train_dataloader, train=True) - self._engine.train() + # set epochs and steps, consider gradient accumulation + self._steps_per_epoch = len(train_dataloader) // self._engine.gradient_accumulation + self._max_steps = max_steps + self._max_epochs = epochs + # check if testing is required should_test = False if test_dataloader is not None: - self._test_dataloader = test_dataloader - self._engine.set_dataloader(self._test_dataloader, train=False) should_test = True - # decide the - if max_epochs is not None: - self._max_epochs = max_epochs - if max_steps is not None: - self._max_steps = max_steps - self._set_display_progress(display_progress) + display_progress = self._should_display_progress(display_progress) + + # reset hooks + self._reset_states() + self.hooks = list() + + # build hooks + if hooks_cfg is not None: + for cfg in hooks_cfg: + hook = build_hooks(cfg, self) + self.hooks.append(hook) + self.hooks.sort(key=lambda hook: hook.priority) + if self._verbose: + for hook in self.hooks: + self._logger.info( + f'build {hook.__class__.__name__} for training, priority = {hook.priority}', ranks=[0]) + self._logger.info("Lower value means higher priority for calling hook function") # start train - self.call_hooks('before_train') + self._engine.train() + self._call_hooks('before_train') # recover step value if resuming training - if self.cur_epoch != 0: - self._recover_steps() - last_epoch = self._cur_epoch + if self.cur_epoch != 0: + self._set_current_step(last_epoch) - for epoch in range(last_epoch, self._max_epochs): - self._cur_epoch += 1 - + for epoch in range(last_epoch, epochs): # train for one epoch - self._train_epoch(epoch) + self._train_epoch( + train_dataloader=train_dataloader, + epoch=epoch, + display_progress=display_progress + ) # start eval if should_test and epoch % test_interval == 0: - self._eval(epoch, return_loss=True) + self._eval(test_dataloader=test_dataloader, + display_progress=display_progress, + epoch=epoch, + ) + + self._cur_epoch += 1 # check for termination - if self.exceed_max_step(): + if self._exceed_max_step(): self._logger.info( - f"Max number of steps {self._max_steps} has been reached, training is stopped automatically") + f"Max number of steps {max_steps} has been reached, training is stopped automatically") break - self.call_hooks('after_train') - self._timer.reset('train-epoch') + self._call_hooks('after_train') + self._call_timer('reset', 'train-epoch') def evaluate(self, test_dataloader: DataLoader, @@ -261,15 +324,13 @@ class Trainer: :type test_dataloader: DataLoader :type display_progress: bool, optional """ - # set dataloader - self._test_dataloader = test_dataloader - self._engine.set_dataloader(self._test_dataloader, train=True) - - # set - self._set_display_progress(display_progress) + # set display + display_progress = self._should_display_progress(display_progress) # eval - self._eval(return_loss=True) + self._eval(test_dataloader=test_dataloader, + display_progress=display_progress, + ) def predict(self, data: Union[Tensor, List[Tensor]]): """Uses trained model to make a prediction for a tensor or a tensor list. @@ -289,45 +350,6 @@ class Trainer: # prepare a list of (data, label) to make it iterable # for compatibility with schedule simple_dataloader = [(data, None)] - self._engine.set_dataloader(simple_dataloader) - output, _, _ = self._engine.step(return_loss=False) + data_iter = iter(simple_dataloader) + output, _, _ = self._engine.step(data_iter, return_loss=False) return output - - def save(self, path: str, suffix: str = ''): - """Saves the model to a file. - - :param path: Relative path of the file - :param suffix: Suffix of the file - :type path: str - :type suffix: str, optional - """ - save_path = get_checkpoint_path(path, - self._cur_epoch, - suffix=suffix) - save_checkpoint(save_path, self._cur_epoch, self._engine.get_model(), - self._engine.get_optimizer(), - self._engine.get_lr_scheduler()) - - def load(self, - path: str, - finetune: bool = False, - strict: bool = False): - """Loads parameters to the model from a file. - - :param path: Relative path of the file - :param finetune: Whether allows to load a part of the model - :param strict: Whether loads a model that has the same shape of parameters - :type path: str - :type finetune: bool, optional - :type strict: bool, optional - """ - last_epoch, _ = load_checkpoint(path, - self._engine.get_model(), - self._engine.get_optimizer(), - self._engine.get_lr_scheduler(), - finetune=finetune, - strict=strict) - if finetune: - self.set_epoch(0) - else: - self.set_epoch(last_epoch) diff --git a/colossalai/trainer/hooks/__init__.py b/colossalai/trainer/hooks/__init__.py index 2cc3c78b7..952bef8b9 100644 --- a/colossalai/trainer/hooks/__init__.py +++ b/colossalai/trainer/hooks/__init__.py @@ -2,10 +2,12 @@ from ._base_hook import BaseHook from ._checkpoint_hook import SaveCheckpointHook, LoadCheckpointHook from ._metric_hook import LossHook, Accuracy2DHook, AccuracyHook, MetricHook from ._log_hook import LogMetricByEpochHook, TensorboardHook, LogTimingByEpochHook, LogMemoryByEpochHook +from ._lr_scheduler_hook import LRSchedulerHook __all__ = [ 'BaseHook', 'MetricHook', 'LoadCheckpointHook', 'SaveCheckpointHook', 'LossHook', 'AccuracyHook', 'Accuracy2DHook', 'LogMetricByEpochHook', 'TensorboardHook', 'LogTimingByEpochHook', 'LogMemoryByEpochHook', + 'LRSchedulerHook' ] diff --git a/colossalai/trainer/hooks/_checkpoint_hook.py b/colossalai/trainer/hooks/_checkpoint_hook.py index 49fd28948..e1d9d4714 100644 --- a/colossalai/trainer/hooks/_checkpoint_hook.py +++ b/colossalai/trainer/hooks/_checkpoint_hook.py @@ -3,13 +3,13 @@ import os.path as osp -import torch.distributed as dist - -from colossalai.checkpointing import get_latest_checkpoint_path, get_checkpoint_path from colossalai.registry import HOOKS -from colossalai.trainer.hooks import BaseHook from colossalai.trainer import Trainer +from colossalai.trainer.hooks import BaseHook from colossalai.utils import is_dp_rank_0 +from colossalai.utils.checkpointing import get_latest_checkpoint_path, get_checkpoint_path +from colossalai.utils.checkpointing import save_checkpoint, load_checkpoint +from ._lr_scheduler_hook import LRSchedulerHook @HOOKS.register_module @@ -33,7 +33,7 @@ class SaveCheckpointHook(BaseHook): interval: int = 1, checkpoint_dir: str = None, suffix: str = '', - priority: int = 0): + priority: int = 10): super().__init__(trainer=trainer, priority=priority) assert isinstance(trainer, Trainer), \ f'SaveCheckpointHook expects a Trainer, got {type(trainer)}' @@ -41,6 +41,16 @@ class SaveCheckpointHook(BaseHook): self.checkpoint_dir = checkpoint_dir self.suffix = suffix + # get lr scheduler from the LRSchedulerHook before train + self._lr_scheduler = None + + def before_train(self): + # check if lr scheduler is present in LRSchedulerHook + for hook in self.trainer.hooks: + if isinstance(hook, LRSchedulerHook): + self._lr_scheduler = hook.lr_scheduler + break + def after_train_epoch(self): """Saves the model after a training epoch. """ @@ -48,14 +58,18 @@ class SaveCheckpointHook(BaseHook): if self.trainer.cur_epoch % self.interval == 0: # only gpus with data parallel rank equals to 0 write to the disk if is_dp_rank_0(): - self.trainer.save(path=self.checkpoint_dir, suffix=self.suffix) + save_path = get_checkpoint_path(self.checkpoint_dir, + self.trainer.cur_epoch, + suffix=self.suffix) + + save_checkpoint(save_path, + self.trainer.cur_epoch, + self.trainer.engine.model, + self.trainer.engine.optimizer, + self._lr_scheduler) self.logger.info( f'checkpoint for epoch {self.trainer.cur_epoch} is saved to {self.checkpoint_dir}') - # wait until everyone is done - if dist.is_initialized(): - dist.barrier() - @HOOKS.register_module class LoadCheckpointHook(BaseHook): @@ -81,30 +95,46 @@ class LoadCheckpointHook(BaseHook): epoch: int = -1, finetune: bool = False, strict: bool = False, - priority: int = 10) -> None: + suffix: str = '', + priority: int = 0) -> None: + super().__init__(trainer=trainer, priority=priority) assert isinstance(trainer, Trainer), \ f'LoadLatestCheckpointHook excepts a Trainer, got {type(trainer)}' self.epoch = epoch self.checkpoint_dir = checkpoint_dir self.finetune = finetune + self.suffix = suffix self.strict = strict - super().__init__(trainer=trainer, priority=priority) def before_train(self): """Loads parameters to the model before training. """ + # check if lr scheduler is present in LRSchedulerHook + lr_scheduler = None + for hook in self.trainer.hooks: + if isinstance(hook, LRSchedulerHook): + lr_scheduler = hook.lr_scheduler + break + + # use latest checkpoint if epoch = -1 if self.epoch == -1: - path = get_latest_checkpoint_path(self.checkpoint_dir) + path = get_latest_checkpoint_path(self.checkpoint_dir, suffix=self.suffix) else: - path = get_checkpoint_path(self.checkpoint_dir, epoch=self.epoch) + path = get_checkpoint_path(self.checkpoint_dir, epoch=self.epoch, suffix=self.suffix) + if osp.exists(path): - self.trainer.load( - path, finetune=self.finetune, strict=self.strict) + last_epoch, _ = load_checkpoint(path, + self.trainer.engine.model, + self.trainer.engine.optimizer, + lr_scheduler, + finetune=self.finetune, + strict=self.strict) + if self.finetune: + self.trainer.cur_epoch = 0 + else: + self.trainer.cur_epoch = last_epoch + self.logger.info( f'loaded checkpoint from {path}') else: raise FileNotFoundError(f'checkpoint is not found at {path}') - - # Some utilities want to load a checkpoint without distributed being initialized - if dist.is_initialized(): - dist.barrier() diff --git a/colossalai/trainer/hooks/_log_hook.py b/colossalai/trainer/hooks/_log_hook.py index d7ed4bf56..3c3fdfc43 100644 --- a/colossalai/trainer/hooks/_log_hook.py +++ b/colossalai/trainer/hooks/_log_hook.py @@ -5,7 +5,7 @@ import os import os.path as osp import torch -from tensorboardX import SummaryWriter +from torch.utils.tensorboard import SummaryWriter from colossalai.context import ParallelMode from colossalai.core import global_context as gpc @@ -13,7 +13,7 @@ from colossalai.registry import HOOKS from colossalai.trainer._trainer import Trainer from colossalai.utils import get_global_multitimer, set_global_multitimer_status, report_memory_usage, is_dp_rank_0, \ is_tp_rank_0, is_no_pp_or_last_stage -from ._metric_hook import MetricHook +from ._base_hook import BaseHook def _format_number(val): @@ -24,7 +24,7 @@ def _format_number(val): return val -class EpochIntervalHook(MetricHook): +class EpochIntervalHook(BaseHook): def __init__(self, trainer: Trainer, interval: int = 1, priority: int = 1): super().__init__(trainer, priority) self._interval = interval @@ -45,7 +45,7 @@ class LogMetricByEpochHook(EpochIntervalHook): :type priority: int, optional """ - def __init__(self, trainer: Trainer, interval: int = 1, priority: int = 1) -> None: + def __init__(self, trainer: Trainer, interval: int = 1, priority: int = 10) -> None: super().__init__(trainer=trainer, interval=interval, priority=priority) self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage() @@ -74,7 +74,7 @@ class LogMetricByEpochHook(EpochIntervalHook): @HOOKS.register_module -class TensorboardHook(MetricHook): +class TensorboardHook(BaseHook): """Specialized Hook to record the metric to Tensorboard. :param trainer: Trainer attached with current hook @@ -85,59 +85,71 @@ class TensorboardHook(MetricHook): :type priority: int, optional """ - def __init__(self, trainer: Trainer, log_dir: str, priority: int = 1) -> None: + def __init__(self, + trainer: Trainer, + log_dir: str, + dp_rank_0_only: bool = True, + tp_rank_0_only: bool = True, + priority: int = 10, + ) -> None: super().__init__(trainer=trainer, priority=priority) - self._is_rank_to_log = is_no_pp_or_last_stage() - if self._is_rank_to_log: + # create log dir + if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0: + os.makedirs(log_dir, exist_ok=True) + + # determine the ranks to generate tensorboard logs + self._is_valid_rank_to_log = is_no_pp_or_last_stage() + + if dp_rank_0_only: + self._is_valid_rank_to_log = self._is_valid_rank_to_log and is_dp_rank_0() + + if tp_rank_0_only: + self._is_valid_rank_to_log = self._is_valid_rank_to_log and is_tp_rank_0() + + if self._is_valid_rank_to_log: # create workspace on only one rank if gpc.is_initialized(ParallelMode.GLOBAL): rank = gpc.get_global_rank() else: rank = 0 - log_dir = osp.join(log_dir, f'rank_{rank}') - # create workspace - if not osp.exists(log_dir): - os.makedirs(log_dir) + log_dir = osp.join(log_dir, f'rank_{rank}') + os.makedirs(log_dir, exist_ok=True) self.writer = SummaryWriter( log_dir=log_dir, filename_suffix=f'_rank_{rank}') - def after_train_iter(self, *args): - for metric_name, metric_calculator in self.trainer.states['metrics']['train'].items(): + def _log_by_iter(self, mode: str): + for metric_name, metric_calculator in self.trainer.states['metrics'][mode].items(): if metric_calculator.epoch_only: continue val = metric_calculator.get_last_step_value() - if self._is_rank_to_log: - self.writer.add_scalar( - f'{metric_name}/train', val, self.trainer.cur_step) - def after_test_iter(self, *args): - for metric_name, metric_calculator in self.trainer.states['metrics']['test'].items(): - if metric_calculator.epoch_only: - continue - val = metric_calculator.get_last_step_value() - if self._is_rank_to_log: - self.writer.add_scalar(f'{metric_name}/test', val, + if self._is_valid_rank_to_log: + self.writer.add_scalar(f'{metric_name}/{mode}', val, self.trainer.cur_step) - def after_test_epoch(self): - for metric_name, metric_calculator in self.trainer.states['metrics']['test'].items(): + def _log_by_epoch(self, mode: str): + for metric_name, metric_calculator in self.trainer.states['metrics'][mode].items(): if metric_calculator.epoch_only: val = metric_calculator.get_accumulated_value() - if self._is_rank_to_log: - self.writer.add_scalar(f'{metric_name}/test', val, + if self._is_valid_rank_to_log: + self.writer.add_scalar(f'{metric_name}/{mode}', val, self.trainer.cur_step) + def after_test_iter(self, *args): + self._log_by_iter(mode='test') + + def after_test_epoch(self): + self._log_by_epoch(mode='test') + + def after_train_iter(self, *args): + self._log_by_iter(mode='train') + def after_train_epoch(self): - for metric_name, metric_calculator in self.trainer.states['metrics']['train'].items(): - if metric_calculator.epoch_only: - val = metric_calculator.get_accumulated_value() - if self._is_rank_to_log: - self.writer.add_scalar(f'{metric_name}/train', val, - self.trainer.cur_step) + self._log_by_epoch(mode='train') @HOOKS.register_module @@ -157,7 +169,7 @@ class LogTimingByEpochHook(EpochIntervalHook): def __init__(self, trainer: Trainer, interval: int = 1, - priority: int = 1, + priority: int = 10, log_eval: bool = True ) -> None: super().__init__(trainer=trainer, interval=interval, priority=priority) @@ -217,7 +229,7 @@ class LogMemoryByEpochHook(EpochIntervalHook): def __init__(self, trainer: Trainer, interval: int = 1, - priority: int = 1, + priority: int = 10, log_eval: bool = True ) -> None: super().__init__(trainer=trainer, interval=interval, priority=priority) diff --git a/colossalai/trainer/hooks/_lr_scheduler_hook.py b/colossalai/trainer/hooks/_lr_scheduler_hook.py new file mode 100644 index 000000000..ca483aebe --- /dev/null +++ b/colossalai/trainer/hooks/_lr_scheduler_hook.py @@ -0,0 +1,58 @@ +from torch import Tensor + +from colossalai.builder import build_lr_scheduler +from colossalai.registry import HOOKS +from ._metric_hook import MetricHook +from .._trainer import Trainer +from ..metric import LearningRate + + +@HOOKS.register_module +class LRSchedulerHook(MetricHook): + """Build LR scheduler + + :param trainer: Trainer attached with current hook + :type trainer: Trainer + :param lr_scheduler_cfg: The config of LR scheduler + :type lr_scheduler_cfg: dict + :param by_epoch: If `True`, the LR will be scheduled every epoch. Else, the LR will be scheduled every batch. Defaults to `True`. + :type by_epoch: bool + :param priority: Priority in the printing, hooks with small priority will be printed in front + :type priority: int, optional + """ + + def __init__(self, + trainer: Trainer, + lr_scheduler_cfg: dict, + by_epoch: bool = True, + store_lr_in_state: bool = True, + priority: int = 1, + ): + super().__init__(trainer=trainer, priority=priority) + self.by_epoch = by_epoch + + if by_epoch: + total_steps = trainer.max_epochs + else: + total_steps = trainer.max_epochs * trainer.steps_per_epoch + if trainer.max_steps is not None: + total_steps = min(total_steps, trainer.max_steps) + + lr_scheduler_cfg['total_steps'] = total_steps + + self.lr_scheduler = build_lr_scheduler( + lr_scheduler_cfg, trainer.engine.optimizer) + + if store_lr_in_state: + self.trainer.states['metrics']['train']['lr'] = LearningRate(epoch_only=by_epoch, + initial_lr=self.lr_scheduler.get_lr()[0]) + + def after_train_epoch(self): + if self.by_epoch: + self.lr_scheduler.step() + self.trainer.states['metrics']['train']['lr'].update(self.lr_scheduler.get_lr()[0]) + + def after_train_iter(self, output: Tensor, label: Tensor, loss: Tensor): + if not self.by_epoch: + self.lr_scheduler.step() + self.trainer.states['metrics']['train']['lr'].update(self.lr_scheduler.get_lr()[0]) diff --git a/colossalai/trainer/hooks/_metric_hook.py b/colossalai/trainer/hooks/_metric_hook.py index 241ec63d3..8c3478c71 100644 --- a/colossalai/trainer/hooks/_metric_hook.py +++ b/colossalai/trainer/hooks/_metric_hook.py @@ -21,9 +21,12 @@ class MetricHook(BaseHook): :type priority: int """ - def __init__(self, trainer: Trainer, priority: int): + def __init__(self, + trainer: Trainer, + priority: int, + ): super().__init__(trainer, priority) - self._is_stage_to_log = is_no_pp_or_last_stage() + self._is_stage_to_compute = is_no_pp_or_last_stage() self._check_metric_states_initialization() def _check_metric_states_initialization(self): @@ -41,33 +44,34 @@ class LossHook(MetricHook): :type priority: int, optional """ - def __init__(self, trainer: Trainer, priority: int = 10): + def __init__(self, trainer: Trainer, priority: int = 0): super().__init__(trainer, priority) - if self._is_stage_to_log: - self.metric = Loss(epoch_only=False) + if self._is_stage_to_compute: + self.train_loss = Loss(epoch_only=False) + self.test_loss = Loss(epoch_only=True) # register the metric calculator self.trainer.states['metrics']['train'][ - self.metric.__class__.__name__] = self.metric + self.train_loss.__class__.__name__] = self.train_loss self.trainer.states['metrics']['test'][ - self.metric.__class__.__name__] = self.metric + self.test_loss.__class__.__name__] = self.test_loss def before_train_epoch(self): - if self._is_stage_to_log: - self.metric.reset() + if self._is_stage_to_compute: + self.train_loss.reset() def after_train_iter(self, logits, label, loss): - if self._is_stage_to_log: - self.metric.update(loss) + if self._is_stage_to_compute: + self.train_loss.update(loss) def before_test_epoch(self): - if self._is_stage_to_log: - self.metric.reset() + if self._is_stage_to_compute: + self.test_loss.reset() def after_test_iter(self, logits, label, loss): - if self._is_stage_to_log: - self.metric.update(loss) + if self._is_stage_to_compute: + self.test_loss.update(loss) @HOOKS.register_module @@ -81,10 +85,10 @@ class Accuracy2DHook(MetricHook): :type priority: int, optional """ - def __init__(self, trainer: Trainer, priority: int = 10): + def __init__(self, trainer: Trainer, priority: int = 0): super().__init__(trainer, priority) - if self._is_stage_to_log: + if self._is_stage_to_compute: self.metric = Accuracy2D(epoch_only=True) # register the metric @@ -92,20 +96,20 @@ class Accuracy2DHook(MetricHook): self.metric.__class__.__name__] = self.metric def before_test(self): - if self._is_stage_to_log: + if self._is_stage_to_compute: self.metric.reset() def after_test_iter(self, logits, label, *args): - if self._is_stage_to_log: + if self._is_stage_to_compute: self.metric.update(logits, label) @HOOKS.register_module class Accuracy2p5DHook(MetricHook): - def __init__(self, trainer: Trainer, priority: int = 10): + def __init__(self, trainer: Trainer, priority: int = 0): super().__init__(trainer, priority) - if self._is_stage_to_log: + if self._is_stage_to_compute: self.metric = Accuracy2p5D(epoch_only=True) # register the metric @@ -113,11 +117,11 @@ class Accuracy2p5DHook(MetricHook): self.metric.__class__.__name__] = self.metric def before_test(self): - if self._is_stage_to_log: + if self._is_stage_to_compute: self.metric.reset() def after_test_iter(self, logits, label, *args): - if self._is_stage_to_log: + if self._is_stage_to_compute: self.metric.update(logits, label) @@ -138,7 +142,7 @@ class Accuracy3DHook(MetricHook): priority: int = 10): super().__init__(trainer, priority) - if self._is_stage_to_log: + if self._is_stage_to_compute: self.metric = Accuracy3D(epoch_only=True, input_parallel_mode=input_parallel_mode, weight_parallel_mode=weight_parallel_mode) @@ -148,11 +152,11 @@ class Accuracy3DHook(MetricHook): self.metric.__class__.__name__] = self.metric def before_test(self): - if self._is_stage_to_log: + if self._is_stage_to_compute: self.metric.reset() def after_test_iter(self, logits, label, *args): - if self._is_stage_to_log: + if self._is_stage_to_compute: self.metric.update(logits, label) @@ -166,10 +170,10 @@ class AccuracyHook(MetricHook): :type priority: int """ - def __init__(self, trainer: Trainer, priority: int = 10): + def __init__(self, trainer: Trainer, priority: int = 0): super().__init__(trainer, priority) - if self._is_stage_to_log: + if self._is_stage_to_compute: self.metric = Accuracy(epoch_only=True) # register the metric @@ -177,9 +181,9 @@ class AccuracyHook(MetricHook): self.metric.__class__.__name__] = self.metric def before_test(self): - if self._is_stage_to_log: + if self._is_stage_to_compute: self.metric.reset() def after_test_iter(self, logits, label, *args): - if self._is_stage_to_log: + if self._is_stage_to_compute: self.metric.update(logits, label) diff --git a/colossalai/trainer/metric.py b/colossalai/trainer/metric.py index 744e0e03a..b595d37b8 100644 --- a/colossalai/trainer/metric.py +++ b/colossalai/trainer/metric.py @@ -126,6 +126,33 @@ class Loss(Metric): return a < b +class LearningRate(Metric): + """A metric collector for learning rate. + + :param epoch_only: Whether the metric only read for the full epoch + :type epoch_only: bool + """ + + def __init__(self, epoch_only: bool, initial_lr: float = 0.): + super().__init__(epoch_only=epoch_only) + self.lr = 0. + + def reset(self) -> None: + pass + + def update(self, lr) -> None: + self.lr = lr + + def get_last_step_value(self): + return self.lr + + def get_accumulated_value(self): + return self.lr + + def is_better(a, b) -> bool: + pass + + class Accuracy(Metric): """A metric collector for accuracy. It only works for classification tasks. diff --git a/colossalai/checkpointing.py b/colossalai/utils/checkpointing.py similarity index 98% rename from colossalai/checkpointing.py rename to colossalai/utils/checkpointing.py index 17db1a1a5..d2cf050cc 100644 --- a/colossalai/checkpointing.py +++ b/colossalai/utils/checkpointing.py @@ -5,9 +5,9 @@ from typing import Tuple import torch -from .context import Config -from .context.parallel_mode import ParallelMode -from .core import global_context as gpc +from colossalai.context import Config +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc __all__ = [ 'get_checkpoint_path', diff --git a/colossalai/utils/common.py b/colossalai/utils/common.py index 1496e77ac..d8c6663ba 100644 --- a/colossalai/utils/common.py +++ b/colossalai/utils/common.py @@ -27,7 +27,7 @@ def sync_model_param_in_dp(model): :param model: A pyTorch nn.model on whose parameters you check the consistency ''' - if gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 2: + if gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1: for param in model.parameters(): ranks = gpc.get_ranks_in_group(ParallelMode.DATA) dist.broadcast(param, src=ranks[0], group=gpc.get_group(ParallelMode.DATA)) diff --git a/configs/resnet/resnet50.py b/configs/resnet/resnet50.py index 57b8b8304..d5ecbdfef 100644 --- a/configs/resnet/resnet50.py +++ b/configs/resnet/resnet50.py @@ -4,6 +4,7 @@ import os IMG_SIZE = 224 BATCH_SIZE = 256 +NUM_EPOCHS = 100 model = dict( type='VanillaResNet', @@ -67,8 +68,6 @@ loss = dict( type='CrossEntropyLoss' ) -max_epochs = 100 - from colossalai.engine import AMP_TYPE fp16 = dict( diff --git a/configs/sample_config.py b/configs/sample_config.py index bfc2d68e2..b9768d2c1 100644 --- a/configs/sample_config.py +++ b/configs/sample_config.py @@ -1,21 +1,20 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +NUM_EPOCH = int + model = dict() train_data = dict() test_data = dict() optimizer = dict() loss = dict() -lr_scheduler = dict() fp16 = dict() zero = dict() gradient_handler = [] parallel = dict() - -num_epochs = int -num_steps = int +hooks = [] cudnn_benchmark = True cudnn_deterministic = False diff --git a/configs/vit/vit_2d.py b/configs/vit/vit_2d.py index 9d09eda2c..f36a03acc 100644 --- a/configs/vit/vit_2d.py +++ b/configs/vit/vit_2d.py @@ -8,10 +8,11 @@ BATCH_SIZE = 512 IMG_SIZE = 32 PATCH_SIZE = 4 DIM = 512 -NUM_ATTENTION_HEADS = 8 +NUM_ATTENTION_HEADS = 2 SUMMA_DIM = 2 NUM_CLASSES = 10 -DEPTH = 6 +DEPTH = 1 +NUM_EPOCHS = 60 train_data = dict( dataset=dict( @@ -127,14 +128,22 @@ hooks = [ dict(type='LogMetricByEpochHook'), dict(type='Accuracy2DHook'), dict(type='LossHook'), - dict(type='TensorboardHook', log_dir='./tfb_logs'), + dict( + type='LRSchedulerHook', + by_epoch=True, + lr_scheduler_cfg=dict( + type='LinearWarmupLR', + warmup_steps=5 + ) + ), + dict(type='TensorboardHook', log_dir='./tb_logs'), # dict(type='SaveCheckpointHook', interval=5, checkpoint_dir='./ckpt'), # dict(type='LoadCheckpointHook', epoch=20, checkpoint_dir='./ckpt') ] parallel = dict( pipeline=dict(size=1), - tensor=dict(size=4, mode='2d'), + tensor=dict(size=1, mode='2d'), ) # for fp16 training @@ -144,17 +153,11 @@ parallel = dict( # initial_scale=2 ** 8 # ) -lr_scheduler = dict( - type='LinearWarmupLR', - warmup_epochs=5 -) - # only needed when pipeline parallel is used # schedule = dict( # num_microbatches=8 # ) -num_epochs = 60 logging = dict( root_path='./logs' diff --git a/configs/vit/vit_3d.py b/configs/vit/vit_3d.py index 037e2c15e..ea605dac8 100644 --- a/configs/vit/vit_3d.py +++ b/configs/vit/vit_3d.py @@ -14,6 +14,7 @@ except: BATCH_SIZE = 512 IMG_SIZE = 32 +NUM_EPOCHS = 60 train_data = dict( dataset=dict( @@ -83,6 +84,14 @@ hooks = [ ), dict(type='LossHook'), dict(type='TensorboardHook', log_dir='./tfb_logs'), + dict( + type='LRSchedulerHook', + by_epoch=True, + lr_scheduler_cfg=dict( + type='LinearWarmupLR', + warmup_steps=5 + ) + ), # dict(type='SaveCheckpointHook', interval=5, checkpoint_dir='./ckpt'), # dict(type='LoadCheckpointHook', epoch=20, checkpoint_dir='./ckpt') ] @@ -97,13 +106,6 @@ fp16 = dict( initial_scale=2 ** 8 ) -lr_scheduler = dict( - type='LinearWarmupLR', - warmup_epochs=5 -) - -num_epochs = 60 - logging = dict( root_path='./logs' ) diff --git a/docs/colossalai/colossalai.engine.amp.amp_type.rst b/docs/colossalai/colossalai.engine.amp.amp_type.rst new file mode 100644 index 000000000..ec1afdfa6 --- /dev/null +++ b/docs/colossalai/colossalai.engine.amp.amp_type.rst @@ -0,0 +1,5 @@ +colossalai.engine.amp.amp\_type +=============================== + +.. automodule:: colossalai.engine.amp.amp_type + :members: diff --git a/docs/colossalai/colossalai.engine.amp.grad_scaler.rst b/docs/colossalai/colossalai.engine.amp.grad_scaler.rst new file mode 100644 index 000000000..752079eab --- /dev/null +++ b/docs/colossalai/colossalai.engine.amp.grad_scaler.rst @@ -0,0 +1,5 @@ +colossalai.engine.amp.grad\_scaler +================================== + +.. automodule:: colossalai.engine.amp.grad_scaler + :members: diff --git a/docs/colossalai/colossalai.engine.amp.rst b/docs/colossalai/colossalai.engine.amp.rst new file mode 100644 index 000000000..987f27f6a --- /dev/null +++ b/docs/colossalai/colossalai.engine.amp.rst @@ -0,0 +1,12 @@ +colossalai.engine.amp +===================== + +.. automodule:: colossalai.engine.amp + :members: + + +.. toctree:: + :maxdepth: 2 + + colossalai.engine.amp.amp_type + colossalai.engine.amp.grad_scaler diff --git a/docs/colossalai/colossalai.engine.amp_type.rst b/docs/colossalai/colossalai.engine.amp_type.rst deleted file mode 100644 index 8121b9933..000000000 --- a/docs/colossalai/colossalai.engine.amp_type.rst +++ /dev/null @@ -1,5 +0,0 @@ -colossalai.engine.amp\_type -=========================== - -.. automodule:: colossalai.engine.amp_type - :members: diff --git a/docs/colossalai/colossalai.engine.rst b/docs/colossalai/colossalai.engine.rst index 1cd4733b8..915be4c98 100644 --- a/docs/colossalai/colossalai.engine.rst +++ b/docs/colossalai/colossalai.engine.rst @@ -7,11 +7,6 @@ colossalai.engine .. toctree:: :maxdepth: 2 + colossalai.engine.amp colossalai.engine.gradient_handler colossalai.engine.schedule - - -.. toctree:: - :maxdepth: 2 - - colossalai.engine.amp_type diff --git a/docs/colossalai/colossalai.rst b/docs/colossalai/colossalai.rst index 414ee8120..a4d4656fd 100644 --- a/docs/colossalai/colossalai.rst +++ b/docs/colossalai/colossalai.rst @@ -21,7 +21,6 @@ colossalai .. toctree:: :maxdepth: 2 - colossalai.checkpointing colossalai.constants colossalai.core colossalai.initialize diff --git a/docs/colossalai/colossalai.utils.checkpointing.rst b/docs/colossalai/colossalai.utils.checkpointing.rst new file mode 100644 index 000000000..534a581d5 --- /dev/null +++ b/docs/colossalai/colossalai.utils.checkpointing.rst @@ -0,0 +1,5 @@ +colossalai.utils.checkpointing +============================== + +.. automodule:: colossalai.utils.checkpointing + :members: diff --git a/docs/colossalai/colossalai.utils.rst b/docs/colossalai/colossalai.utils.rst index bfe62172f..7f712e313 100644 --- a/docs/colossalai/colossalai.utils.rst +++ b/docs/colossalai/colossalai.utils.rst @@ -9,6 +9,7 @@ colossalai.utils :maxdepth: 2 colossalai.utils.activation_checkpoint + colossalai.utils.checkpointing colossalai.utils.common colossalai.utils.cuda colossalai.utils.memory diff --git a/docs/parallelization.md b/docs/parallelization.md index ca98d542b..0c1e70bfe 100644 --- a/docs/parallelization.md +++ b/docs/parallelization.md @@ -17,38 +17,40 @@ parallel = dict( ) ``` -The name of the dictionary variable should be **parallel**. All the arguments even **parallel** itself are optional and data, -pipeline, tensor parallel size will be set to defaulted value 1. The value of data, pipeline and tensor can be a int -representing the size of specific parallel dimension or a dictionary with a key called "size". The key "mode" +The name of the dictionary variable should be **parallel**. All the arguments even **parallel** itself are optional and +data, pipeline, tensor parallel size will be set to defaulted value 1. The value of data, pipeline and tensor can be a +int representing the size of specific parallel dimension or a dictionary with a key called "size". The key "mode" represents the way of tensor parallelism. ## Data Parallel -Data parallel is the most common way to distribute your training task by splitting data into several shards and train -on a single shard on each device. The configuration for data parallel is detected automatically and set for you. You do -not have to explicitly set them in your configurations. When data parallel size is larger than 1, Colossal-AI automatically +Data parallel is the most common way to distribute your training task by splitting data into several shards and train on +a single shard on each device. The configuration for data parallel is detected automatically and set for you. You do not +have to explicitly set them in your configurations. When data parallel size is larger than 1, Colossal-AI automatically adds the distributed data sampler to the dataloader to shard the dataset. ## 1D, 2D, 2.5D and 3D Parallel -To enable hybrid parallelism, we provide an array of tensor parallelism. We provide the list of papers which match each +To enable hybrid parallelism, we provide an array of tensor parallelism. We provide the list of papers which match each tensor parallel method. These parallel modes need to work with the distributed layers provided by Colossal-AI. -- 1D: [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) + +- +1D: [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) - 2D: [An Efficient 2D Method for Training Super-Large Deep Learning Models](https://arxiv.org/abs/2104.05343) -2D parallel relies on the SUMMA matrix multiplication algorithm and splits the input data, -model weights and layer outputs along two different dimensions. The tensor chunks are distributed over a 2D mesh of $P = N^2$ -devices where $N$ is the number of tensor chunks in a single dimension. + 2D parallel relies on the SUMMA matrix multiplication algorithm and splits the input data, model weights and layer + outputs along two different dimensions. The tensor chunks are distributed over a 2D mesh of $P = N^2$ devices where + $N$ is the number of tensor chunks in a single dimension. - 2.5D: [2.5-dimensional distributed model training](https://arxiv.org/abs/2105.14500) -Inspired by the 2.5D matrix multiplication algorithm, 2.5D parallel introduces a novel tensor parallelism which further -parallelizes 2D tensor parallelism. An amount of $P = N^2 โˆ— d$ processors are arranged into $d$ layers, -where each layer performs matrix multiplication operations independently with a dimension $N$. + Inspired by the 2.5D matrix multiplication algorithm, 2.5D parallel introduces a novel tensor parallelism which + further parallelizes 2D tensor parallelism. An amount of $P = N^2 โˆ— d$ processors are arranged into $d$ layers, where + each layer performs matrix multiplication operations independently with a dimension $N$. - 3D: [Maximizing Parallelism in Distributed Training for Huge Neural Networks](https://arxiv.org/abs/2105.14450) -We also introduce a 3D tensor parallelism that parallelizes neural networks on a 3D processor cube. This method achieves -the optimal, $O(P^{1/3})$ communication overhead on $P$ processors, while both computation and memory usage are evenly distributed -through optimized load balancing of parameters as well as activations. + We also introduce a 3D tensor parallelism that parallelizes neural networks on a 3D processor cube. This method + achieves the optimal, $O(P^{1/3})$ communication overhead on $P$ processors, while both computation and memory usage + are evenly distributed through optimized load balancing of parameters as well as activations. ```python # 1D parallel @@ -78,12 +80,12 @@ parallel = dict( ## Pipeline Parallel (experimental) -Pipeline parallelism is to split the model into several partitions by layer. For example, let's assume we have a simple -model which consists of two linear layer. We have two GPUs, and we can allocate the first linear layer to the first GPU +Pipeline parallelism is to split the model into several partitions by layer. For example, let's assume we have a simple +model which consists of two linear layer. We have two GPUs, and we can allocate the first linear layer to the first GPU and the second layer to the second GPU. This example of course wastes the computing resources and is only to demonstrate -the idea of pipeline parallelism. +the idea of pipeline parallelism. -As PyTorch is based on dynamic computation graph, the computation flow is not known until execution. To support pipeline +As PyTorch is based on dynamic computation graph, the computation flow is not known until execution. To support pipeline parallelism in PyTorch, you may need to add one more attribute, `layers_cfg` in your model class which tells Colossal-AI the sequence of execution. One example you can refer is `colossalai.nn.model.VanillaResNet`. @@ -192,9 +194,9 @@ class VanillaResNet(BaseModel): ] ``` -You can set the number of pipeline stages in your configuration file. When pipeline size is larger than 1, Colossal-AI -will automatically creates the pipeline schedule which defines the forward and backward step. You can specify how many microbatches -to run in each step in the `schedule` configuration. +You can set the number of pipeline stages in your configuration file. When pipeline size is larger than 1, Colossal-AI +will automatically creates the pipeline schedule which defines the forward and backward step. You can specify how many +microbatches to run in each step in the `schedule` configuration. ```python parallel = dict( @@ -206,10 +208,11 @@ schedule = dict( num_microbatches = 4 # set the number of microbatches per step ) ``` + This feature is still in development and is only experimental for now. ## Sequence Parallel (experimental) -Sequence parallel is to support long-sequence modelling such as document-level text understanding and medical imaging. -This method is proposed in [Sequence Parallelism: Making 4D Parallelism Possible](https://arxiv.org/abs/2105.13120). +Sequence parallel is to support long-sequence modelling such as document-level text understanding and medical imaging. +This method is proposed in [Sequence Parallelism: Making 4D Parallelism Possible](https://arxiv.org/abs/2105.13120). This feature is still in development and is only experimental for now. diff --git a/docs/run_demo.md b/docs/run_demo.md index 48f0590d3..6d8c5b49a 100644 --- a/docs/run_demo.md +++ b/docs/run_demo.md @@ -1,8 +1,8 @@ # Quick demo -Colossal-AI is an integrated large-scale deep learning system with efficient parallelization techniques. The system -can accelerate model training on distributed systems with multiple GPUs by applying parallelization techniques. The -system can also run on systems with only one GPU. Quick demos showing how to use Colossal-AI are given below. +Colossal-AI is an integrated large-scale deep learning system with efficient parallelization techniques. The system can +accelerate model training on distributed systems with multiple GPUs by applying parallelization techniques. The system +can also run on systems with only one GPU. Quick demos showing how to use Colossal-AI are given below. ## Single GPU @@ -32,25 +32,17 @@ realizes the training process. ```python import colossalai from colossalai.core import global_context as gpc -from colossalai.engine import Engine from colossalai.logging import get_global_dist_logger from colossalai.trainer import Trainer + def run_trainer(): - model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize() + engine, train_dataloader, test_dataloader = colossalai.initialize() logger = get_global_dist_logger() - schedule.data_sync = False - engine = Engine( - model=model, - criterion=criterion, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - schedule=schedule - ) + logger.info("engine is built", ranks=[0]) trainer = Trainer(engine=engine, - hooks_cfg=gpc.config.hooks, verbose=True) logger.info("trainer is built", ranks=[0]) @@ -58,11 +50,13 @@ def run_trainer(): trainer.fit( train_dataloader=train_dataloader, test_dataloader=test_dataloader, - max_epochs=gpc.config.num_epochs, + epochs=gpc.config.num_epochs, + hooks_cfg=gpc.config.hooks, display_progress=True, test_interval=2 ) + if __name__ == '__main__': run_trainer() ``` @@ -72,9 +66,9 @@ Zoo. The detailed substitution process is elaborated [here](model.md). ## Features -Colossal-AI provides a collection of parallel training components for you. We aim to support you with your development of -distributed deep learning models just like how you write single-GPU deep learning models. We provide friendly tools to -kickstart distributed training in a few lines. +Colossal-AI provides a collection of parallel training components for you. We aim to support you with your development +of distributed deep learning models just like how you write single-GPU deep learning models. We provide friendly tools +to kickstart distributed training in a few lines. - [Data Parallelism](parallelization.md) - [Pipeline Parallelism](parallelization.md) diff --git a/docs/run_demo_zh.md b/docs/run_demo_zh.md index a52fcfd79..54839760d 100644 --- a/docs/run_demo_zh.md +++ b/docs/run_demo_zh.md @@ -4,40 +4,36 @@ Colossal-AIๆ˜ฏไธ€ไธชๅคง่ง„ๆจกๆทฑๅบฆๅญฆไน ็ณป็ปŸ๏ผŒๅ…ถไธญๅŒ…ๅซ้ซ˜ๆ•ˆ็š„ๅนถ่กŒๆŠ€ ## ๅ•GPU็ณป็ปŸ -ๅœจๅธฆๆœ‰GPU็š„้žๅˆ†ๅธƒๅผ็ณป็ปŸไธŠ่ฟ›่กŒๆจกๅž‹่ฎญ็ปƒๆ—ถ๏ผŒColossal-AIๅฏไปฅ่พพๅˆฐๅฝ“ๅ‰็š„ๅŸบ็บฟๆ•ˆ็Ž‡ใ€‚[่ฟ™้‡Œ](https://colab.research.google.com/drive/1fJnqqFzPuzZ_kn1lwCpG2nh3l2ths0KE?usp=sharing#scrollTo=cQ_y7lBG09LS)ๆˆ‘ไปฌ็ป™ๅ‡บไธ€ไธชGoogle Colab็คบไพ‹ๅฑ•็Žฐๅฆ‚ไฝ•ไฝฟ็”จColossal-AIไธŽCIFAR10ๆ•ฐๆฎ้›†ๅœจ้žๅˆ†ๅธƒๅผ็ณป็ปŸไธŠ่ฎญ็ปƒไธ€ไธชLeNetๆจกๅž‹ใ€‚ +ๅœจๅธฆๆœ‰GPU็š„้žๅˆ†ๅธƒๅผ็ณป็ปŸไธŠ่ฟ›่กŒๆจกๅž‹่ฎญ็ปƒๆ—ถ๏ผŒColossal-AIๅฏไปฅ่พพๅˆฐๅฝ“ๅ‰็š„ๅŸบ็บฟๆ•ˆ็Ž‡ใ€‚[่ฟ™้‡Œ](https://colab.research.google.com/drive/1fJnqqFzPuzZ_kn1lwCpG2nh3l2ths0KE?usp=sharing#scrollTo=cQ_y7lBG09LS)ๆˆ‘ไปฌ็ป™ๅ‡บไธ€ไธชGoogle +Colab็คบไพ‹ๅฑ•็Žฐๅฆ‚ไฝ•ไฝฟ็”จColossal-AIไธŽCIFAR10ๆ•ฐๆฎ้›†ๅœจ้žๅˆ†ๅธƒๅผ็ณป็ปŸไธŠ่ฎญ็ปƒไธ€ไธชLeNetๆจกๅž‹ใ€‚ ## ๅคšGPU็ณป็ปŸ -ๅœจๅคšGPU็š„ๅˆ†ๅธƒๅผ็ณป็ปŸไธŠ่ฎญ็ปƒๆทฑๅบฆๅญฆไน ๆจกๅž‹ๆ—ถ๏ผŒColossal-AIๅฏไปฅไฝฟ็”จ้ซ˜ๆ•ˆ็š„ๅนถ่กŒๆŠ€ๆœฏๆฅๆ˜พ่‘—ๅœฐๅŠ ้€Ÿ่ฎญ็ปƒ่ฟ‡็จ‹๏ผŒ่ฟ™ไบ›ๆŠ€ๆœฏๅฐ†ๅœจไธ‹้ข็š„[ๅนถ่กŒๆŠ€ๆœฏ](parallelization.md)็ซ ่Š‚ไธญ่ขซ่ฏฆ่ฟฐใ€‚ไธ‹้ข็š„ไปฃ็ ๅฐ†ๅœจๆ‹ฅๆœ‰ๅ››ไธชGPU็š„ๅˆ†ๅธƒๅผ็ณป็ปŸไธŠ่ฎญ็ปƒไธ€ไธชViTๆจกๅž‹๏ผŒๅ…ถไธญ`HOST`ๅ˜้‡ไธบๆ‚จๅˆ†ๅธƒๅผ็ณป็ปŸ็š„IPๅœฐๅ€ใ€‚่ฏทๆณจๆ„ไธ‹้ข็š„ไปฃ็ ไฝฟ็”จไบ†[Slurm](https://slurm.schedmd.com/documentation.html)ไฝœไธš่ฐƒๅบฆ็ณป็ปŸใ€‚ +ๅœจๅคšGPU็š„ๅˆ†ๅธƒๅผ็ณป็ปŸไธŠ่ฎญ็ปƒๆทฑๅบฆๅญฆไน ๆจกๅž‹ๆ—ถ๏ผŒColossal-AIๅฏไปฅไฝฟ็”จ้ซ˜ๆ•ˆ็š„ๅนถ่กŒๆŠ€ๆœฏๆฅๆ˜พ่‘—ๅœฐๅŠ ้€Ÿ่ฎญ็ปƒ่ฟ‡็จ‹๏ผŒ่ฟ™ไบ›ๆŠ€ๆœฏๅฐ†ๅœจไธ‹้ข็š„[ๅนถ่กŒๆŠ€ๆœฏ](parallelization.md) +็ซ ่Š‚ไธญ่ขซ่ฏฆ่ฟฐใ€‚ไธ‹้ข็š„ไปฃ็ ๅฐ†ๅœจๆ‹ฅๆœ‰ๅ››ไธชGPU็š„ๅˆ†ๅธƒๅผ็ณป็ปŸไธŠ่ฎญ็ปƒไธ€ไธชViTๆจกๅž‹๏ผŒๅ…ถไธญ`HOST` +ๅ˜้‡ไธบๆ‚จๅˆ†ๅธƒๅผ็ณป็ปŸ็š„IPๅœฐๅ€ใ€‚่ฏทๆณจๆ„ไธ‹้ข็š„ไปฃ็ ไฝฟ็”จไบ†[Slurm](https://slurm.schedmd.com/documentation.html)ไฝœไธš่ฐƒๅบฆ็ณป็ปŸใ€‚ ```bash HOST=xxx.xxx.xxx.xxx srun ./scripts/slurm_dist_train.sh ./examples/run_trainer.py ./configs/vit/vit_2d.py ``` -`./configs/vit/vit_2d.py`ๆ˜ฏไธ€ไธช[้…็ฝฎๆ–‡ไปถ](config.md)๏ผŒColossal-AIไฝฟ็”จ้…็ฝฎๆ–‡ไปถๆฅๅฎšไน‰่ฎญ็ปƒ่ฟ‡็จ‹ไธญ้œ€่ฆ็”จๅˆฐ็š„ๅ‚ๆ•ฐ๏ผŒๆฏ”ๅฆ‚ๆจกๅž‹็ฑปๅž‹ใ€ๆ•ฐๆฎ้›†ใ€ไปฅๅŠไผ˜ๅŒ–ๅ™จใ€ๅญฆไน ็Ž‡่ฐƒๅบฆๅ™จ็ญ‰ใ€‚ๆ‚จๅฏไปฅ้€š่ฟ‡็ผ–ๅ†™้…็ฝฎๆ–‡ไปถ็š„ๆ–นๅผๆฅ่ฎญ็ปƒไธๅŒ็š„ๆจกๅž‹ใ€‚`./examples/run_trainer.py`ๆ˜ฏไธ€ไธชๆ ‡ๅ‡†็š„่ฎญ็ปƒ่„šๆœฌ๏ผŒๅ…ทไฝ“ไปฃ็ ๅทฒ็ป้™„ๅœจไธ‹้ขใ€‚่ฏฅ่„šๆœฌๅฏไปฅ่ฏปๅ…ฅ้…็ฝฎๆ–‡ไปถไธญ็š„่ฎญ็ปƒๅ‚ๆ•ฐๅนถ่ฎญ็ปƒๆจกๅž‹ใ€‚ +`./configs/vit/vit_2d.py`ๆ˜ฏไธ€ไธช[้…็ฝฎๆ–‡ไปถ](config.md) +๏ผŒColossal-AIไฝฟ็”จ้…็ฝฎๆ–‡ไปถๆฅๅฎšไน‰่ฎญ็ปƒ่ฟ‡็จ‹ไธญ้œ€่ฆ็”จๅˆฐ็š„ๅ‚ๆ•ฐ๏ผŒๆฏ”ๅฆ‚ๆจกๅž‹็ฑปๅž‹ใ€ๆ•ฐๆฎ้›†ใ€ไปฅๅŠไผ˜ๅŒ–ๅ™จใ€ๅญฆไน ็Ž‡่ฐƒๅบฆๅ™จ็ญ‰ใ€‚ๆ‚จๅฏไปฅ้€š่ฟ‡็ผ–ๅ†™้…็ฝฎๆ–‡ไปถ็š„ๆ–นๅผๆฅ่ฎญ็ปƒไธๅŒ็š„ๆจกๅž‹ใ€‚`./examples/run_trainer.py` +ๆ˜ฏไธ€ไธชๆ ‡ๅ‡†็š„่ฎญ็ปƒ่„šๆœฌ๏ผŒๅ…ทไฝ“ไปฃ็ ๅทฒ็ป้™„ๅœจไธ‹้ขใ€‚่ฏฅ่„šๆœฌๅฏไปฅ่ฏปๅ…ฅ้…็ฝฎๆ–‡ไปถไธญ็š„่ฎญ็ปƒๅ‚ๆ•ฐๅนถ่ฎญ็ปƒๆจกๅž‹ใ€‚ ```python import colossalai from colossalai.core import global_context as gpc -from colossalai.engine import Engine from colossalai.logging import get_global_dist_logger from colossalai.trainer import Trainer + def run_trainer(): - model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize() + engine, train_dataloader, test_dataloader = colossalai.initialize() logger = get_global_dist_logger() - schedule.data_sync = False - engine = Engine( - model=model, - criterion=criterion, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - schedule=schedule - ) logger.info("engine is built", ranks=[0]) trainer = Trainer(engine=engine, - hooks_cfg=gpc.config.hooks, verbose=True) logger.info("trainer is built", ranks=[0]) @@ -45,11 +41,13 @@ def run_trainer(): trainer.fit( train_dataloader=train_dataloader, test_dataloader=test_dataloader, - max_epochs=gpc.config.num_epochs, + epochs=gpc.config.num_epochs, + hooks_cfg=gpc.config.hooks, display_progress=True, test_interval=2 ) + if __name__ == '__main__': run_trainer() ``` diff --git a/docs/trainer_engine.md b/docs/trainer_engine.md index 276134021..88b872826 100644 --- a/docs/trainer_engine.md +++ b/docs/trainer_engine.md @@ -2,9 +2,9 @@ ## Build your engine -To better understand how `Engine` class works, let's start from the conception of the process function in common engines. The process function -usually controls the behavior over a batch of a dataset, `Engine` class just controls the process function. Here we give a standard process -function in the following code block. +To better understand how `Engine` class works, let's start from the conception of the process function in common +engines. The process function usually controls the behavior over a batch of a dataset, `Engine` class just controls the +process function. Here we give a standard process function in the following code block. ```python def process_function(dataloader, model, criterion, optim): @@ -16,32 +16,33 @@ def process_function(dataloader, model, criterion, optim): optim.setp() ``` -In `ignite.engine` or `keras.engine`, the process function is always provided by users. However, it is tricky for users to write their own process -functions for pipeline parallelism. Aiming at offering accessible hybrid parallelism for users, we provide the powerful `Engine` class. This class -enables pipeline parallelism and offers one-forward-one-backward non-interleaving strategy. Also, you can use pre-defined learning rate scheduler -in the `Engine` class to adjust learning rate during training. +In `ignite.engine` or `keras.engine`, the process function is always provided by users. However, it is tricky for users +to write their own process functions for pipeline parallelism. Aiming at offering accessible hybrid parallelism for +users, we provide the powerful `Engine` class. This class enables pipeline parallelism and offers +one-forward-one-backward non-interleaving strategy. Also, you can use pre-defined learning rate scheduler in +the `Engine` class to adjust learning rate during training. -In order to build your engine, just set variables `model`, `criterion`, `optimizer`, `lr_scheduler` and `schedule`. The following code block provides -an example. +In order to build your engine, just set variables `model`, `criterion`, `optimizer`, `lr_scheduler` and `schedule`. The +following code block provides an example. **The engine is automatically created from the config file for you if you +start with `colossalai.initialize`.** ```python import torch import torch.nn as nn import torchvision.models as models import colossalai +from colossalai.engine import Engine model = models.resnet18() criterion = nn.CrossEntropyLoss() -optimizer = torch.optim.Adam(model) -lr_scheduler = colossalai.nn.lr_scheduler.CosineAnnealingLR(optimizer, 1000) -schedule = colossalai.engine.schedule.NoPipelineSchedule() +optimizer = torch.optim.Adam(model.parameters()) +schedule = colossalai.engine.NoPipelineSchedule() MyEngine = Engine( model=model, criterion=criterion, optimizer=optimizer, - lr_scheduler=lr_scheduler, - schedule=schedule + step_schedule=schedule ) ``` @@ -51,21 +52,24 @@ More information regarding the class can be found in the API references. ### Overview -To learn how to customize a trainer which meets your needs, let's first give a look at the `Trainer` class. We highly recommend that you read *Get Started* +To learn how to customize a trainer which meets your needs, let's first give a look at the `Trainer` class. We highly +recommend that you read *Get Started* section and *Build your engine* first. -The `Trainer` class enables researchers and engineers to use our system more conveniently. Instead of having to write your own scripts, you can simply -construct your own trainer by calling the `Trainer` class, just like what we did in the following code block. +The `Trainer` class enables researchers and engineers to use our system more conveniently. Instead of having to write +your own scripts, you can simply construct your own trainer by calling the `Trainer` class, just like what we did in the +following code block. ```python -MyTrainer = Trainer(MyEngine) +MyTrainer = Trainer(my_engine) ``` -After that, you can use the `fit` method to train or evaluate your model. In order to make our `Trainer` class even more powerful, we incorporate a set of -handy tools to the class. For example, you can monitor or record the running states and metrics which indicate the current performance of the model. These -functions are realized by hooks. The `BasicHook` class allows you to execute your hook functions at specified time. We have already created some practical -hooks for you, as listed below. What you need to do is just picking the right ones which suit your needs. Detailed descriptions of the class can be found -in the API references. +After that, you can use the `fit` method to train or evaluate your model. In order to make our `Trainer` class even more +powerful, we incorporate a set of handy tools to the class. For example, you can monitor or record the running states +and metrics which indicate the current performance of the model. These functions are realized by hooks. The `BasicHook` +class allows you to execute your hook functions at specified time. We have already created some practical hooks for you, +as listed below. What you need to do is just picking the right ones which suit your needs. Detailed descriptions of the +class can be found in the API references. ```python hooks = [ @@ -80,18 +84,21 @@ hooks = [ ] ``` -These hook functions will record metrics, elapsed time and memory usage and write them to log after each epoch. Besides, they print the current loss and -accuracy to let users monitor the performance of the model. +These hook functions will record metrics, elapsed time and memory usage and write them to log after each epoch. Besides, +they print the current loss and accuracy to let users monitor the performance of the model. ### Hook -If you have your specific needs, feel free to extend our `BaseHook` class to add your own functions, or our `MetricHook` class to write a metric collector. -These hook functions can be called at twelve timing in the trainer's life cycle. Besides, you can define the priorities of all hooks to arrange the execution order of them. -More information can be found in the API references. +If you have your specific needs, feel free to extend our `BaseHook` class to add your own functions, or our `MetricHook` +class to write a metric collector. These hook functions can be called at twelve timing in the trainer's life cycle. +Besides, you can define the priorities of all hooks to arrange the execution order of them. More information can be +found in the API references. ### Metric -You can write your own metrics by extending our `Metric` class. It should be used with the `MetricHook` class. When your write your own metric hooks, please set -the priority carefully and make sure the hook is called before other hooks which might require the results of the metric hook. +You can write your own metrics by extending our `Metric` class. It should be used with the `MetricHook` class. When your +write your own metric hooks, please set the priority carefully and make sure the hook is called before other hooks which +might require the results of the metric hook. -We've already provided some metric hooks and we store metric objects in `runner.states['metrics']`. It is a dictionary and metrics can be accessed by their names. +We've already provided some metric hooks and we store metric objects in `runner.states['metrics']`. It is a dictionary +and metrics can be accessed by their names. diff --git a/docs/trainer_engine_zh.md b/docs/trainer_engine_zh.md index 0e2df3fdd..737d6745b 100644 --- a/docs/trainer_engine_zh.md +++ b/docs/trainer_engine_zh.md @@ -14,28 +14,30 @@ def process_function(dataloader, model, criterion, optim): optim.setp() ``` -ๅœจ`ignite.engine`ไธŽ`keras.engine`ไธญ๏ผŒ่ฟ›็จ‹ๅ‡ฝๆ•ฐ้œ€่ฆ็”ฑ็”จๆˆทๆไพ›๏ผŒ็„ถ่€Œ๏ผŒ็”จๆˆทๅพˆ้šพไธบๆตๆฐด็บฟๅนถ่กŒ็ผ–ๅ†™่ฟ›็จ‹ๅ‡ฝๆ•ฐใ€‚ไธบไบ†ๅ‘็”จๆˆทๆไพ›ๆ–นไพฟ็š„ๆททๅˆๅนถ่กŒ๏ผŒๆˆ‘ไปฌๆไพ›ไบ†ๅ…ทๅค‡ๅผบๅคงๅŠŸ่ƒฝ็š„`Engine`็ฑป๏ผŒ่ฏฅ็ฑปๆ”ฏๆŒๆตๆฐด็บฟๅนถ่กŒ๏ผŒๅนถๆไพ›ๅ‰ๅ‘ไผ ๆ’ญๅŽๅ‘ไผ ๆ’ญไธไบค็ป‡็š„็ญ–็•ฅใ€‚ๅŒๆ—ถ๏ผŒๆ‚จๅฏไปฅๅœจ`Engine`็ฑปไธญไฝฟ็”จๆ‚จไบ‹ๅ…ˆๅฎšไน‰ๅฅฝ็š„ๅญฆไน ็Ž‡่ฐƒๅบฆๅ™จๆฅๅœจ่ฎญ็ปƒ่ฟ‡็จ‹ไธญ่ฐƒๆ•ดๅญฆไน ็Ž‡ใ€‚ +ๅœจ`ignite.engine`ไธŽ`keras.engine`ไธญ๏ผŒ่ฟ›็จ‹ๅ‡ฝๆ•ฐ้œ€่ฆ็”ฑ็”จๆˆทๆไพ›๏ผŒ็„ถ่€Œ๏ผŒ็”จๆˆทๅพˆ้šพไธบๆตๆฐด็บฟๅนถ่กŒ็ผ–ๅ†™่ฟ›็จ‹ๅ‡ฝๆ•ฐใ€‚ไธบไบ†ๅ‘็”จๆˆทๆไพ›ๆ–นไพฟ็š„ๆททๅˆๅนถ่กŒ๏ผŒๆˆ‘ไปฌๆไพ›ไบ†ๅ…ทๅค‡ๅผบๅคงๅŠŸ่ƒฝ็š„`Engine` +็ฑป๏ผŒ่ฏฅ็ฑปๆ”ฏๆŒๆตๆฐด็บฟๅนถ่กŒ๏ผŒๅนถๆไพ›ๅ‰ๅ‘ไผ ๆ’ญๅŽๅ‘ไผ ๆ’ญไธไบค็ป‡็š„็ญ–็•ฅใ€‚ๅŒๆ—ถ๏ผŒๆ‚จๅฏไปฅๅœจ`Engine`็ฑปไธญไฝฟ็”จๆ‚จไบ‹ๅ…ˆๅฎšไน‰ๅฅฝ็š„ๅญฆไน ็Ž‡่ฐƒๅบฆๅ™จๆฅๅœจ่ฎญ็ปƒ่ฟ‡็จ‹ไธญ่ฐƒๆ•ดๅญฆไน ็Ž‡ใ€‚ ๆ‚จๅœจๆž„้€ ๅผ•ๆ“Žๆ—ถๅช้œ€่ฆๅฎšไน‰`model`ใ€`criterion`ใ€`optimizer`ใ€`lr_scheduler`ไธŽ`schedule`็ญ‰ๅ˜้‡ๅณๅฏ๏ผŒไธ‹้ข็š„ไปฃ็ ๅ—็ป™ๅ‡บไบ†ไธ€ไธช่ฟ™ๆ ท็š„ไพ‹ๅญใ€‚ +**ๅฆ‚ๆžœไฝ ไฝฟ็”จ`colossalai.initialize`็š„่ฏ๏ผŒengineไผšไปŽconfigๆ–‡ไปถ้‡Œ่‡ชๅŠจๆž„ๅปบใ€‚** ```python import torch import torch.nn as nn import torchvision.models as models import colossalai +from colossalai.engine import Engine model = models.resnet18() criterion = nn.CrossEntropyLoss() optimizer = torch.optim.Adam(model) lr_scheduler = colossalai.nn.lr_scheduler.CosineAnnealingLR(optimizer, 1000) -schedule = colossalai.engine.schedule.NoPipelineSchedule() +schedule = colossalai.engine.NoPipelineSchedule() MyEngine = Engine( model=model, criterion=criterion, optimizer=optimizer, - lr_scheduler=lr_scheduler, - schedule=schedule + step_schedule=schedule ) ``` @@ -48,10 +50,12 @@ MyEngine = Engine( `Trainer`็ฑปๆ—จๅœจ่ฎฉ็ง‘็ ”ๅทฅไฝœ่€…ๅ’Œๅทฅ็จ‹ๅธˆๆ›ดๅŠ ๆ–นไพฟๅœฐไฝฟ็”จๆˆ‘ไปฌ็š„็ณป็ปŸ๏ผŒๆ‚จไธ้œ€่ฆ่‡ชๅทฑๅ†™่„šๆœฌ๏ผŒๅช้œ€่ฆ่ฐƒ็”จ`Trainer`็ฑปๆฅๆž„้€ ๆ‚จ็š„่ฎญ็ปƒๅ™จๅณๅฏ๏ผŒๅฐฑๅƒไธ‹้ข็š„ไปฃ็ ๅ—ไธญๆ‰€ๅš็š„ใ€‚ ```python -MyTrainer = Trainer(MyEngine) +MyTrainer = Trainer(my_trainer) ``` -ๅœจๆญคไน‹ๅŽ๏ผŒๆ‚จๅฏไปฅไฝฟ็”จ`fit`ๆ–นๆณ•ๆฅ่ฎญ็ปƒๆˆ–่ฐƒ็”จๆ‚จ็š„ๆจกๅž‹ใ€‚้™คๆญคไน‹ๅค–๏ผŒไธบไบ†่ฎฉๆˆ‘ไปฌ็š„`Trainer`็ฑปๆ‹ฅๆœ‰ๆ›ดๅผบๅคง็š„ๅŠŸ่ƒฝ๏ผŒๆˆ‘ไปฌๅŠ ๅ…ฅไบ†ไธ€็ณปๅˆ—ๆ–นไพฟๆ‚จไฝฟ็”จ็š„ๅทฅๅ…ทใ€‚ไพ‹ๅฆ‚๏ผŒๆ‚จๅฏไปฅๅœจ่ฎญ็ปƒ่ฟ‡็จ‹ไธญๆŒ็ปญ็›‘ๆต‹ๅนถ่ฎฐๅฝ•ๆจกๅž‹็›ฎๅ‰็š„่ฟ่กŒ็Šถๆ€ๅ’Œ่กจ็Žฐ๏ผŒ่ฟ™ไบ›ๅŠŸ่ƒฝ้ƒฝๆ˜ฏ้€š่ฟ‡้’ฉๅญๅ‡ฝๆ•ฐๆฅๅฎž็Žฐ็š„ใ€‚ๆˆ‘ไปฌๆไพ›็š„`BasicHook`็ฑป่ฎฉๆ‚จๅฏไปฅๅœจๆŒ‡ๅฎšๆ—ถ้—ดๆ‰ง่กŒๆ‚จ็š„้’ฉๅญๅ‡ฝๆ•ฐใ€‚ๅฆ‚ไธ‹ๆ–น็š„ไปฃ็ ๅ—ๆ‰€็คบ๏ผŒๆˆ‘ไปฌไบ‹ๅ…ˆไธบๆ‚จๅฎšไน‰ๅฅฝไบ†ไธ€ไบ›ๅฎž็”จ็š„้’ฉๅญๅ‡ฝๆ•ฐ๏ผŒๆ‚จ้œ€่ฆๅš็š„ๅฐฑๆ˜ฏๆ‰พๅˆฐ็ฌฆๅˆๆ‚จ้œ€ๆฑ‚็š„้’ฉๅญๅ‡ฝๆ•ฐใ€‚ๆ›ดๅคš่ฏฅ็ฑป็š„็›ธๅ…ณไฟกๆฏๅฏไปฅๅœจAPIไฟกๆฏไธญๆ‰พๅˆฐใ€‚ +ๅœจๆญคไน‹ๅŽ๏ผŒๆ‚จๅฏไปฅไฝฟ็”จ`fit`ๆ–นๆณ•ๆฅ่ฎญ็ปƒๆˆ–่ฐƒ็”จๆ‚จ็š„ๆจกๅž‹ใ€‚้™คๆญคไน‹ๅค–๏ผŒไธบไบ†่ฎฉๆˆ‘ไปฌ็š„`Trainer` +็ฑปๆ‹ฅๆœ‰ๆ›ดๅผบๅคง็š„ๅŠŸ่ƒฝ๏ผŒๆˆ‘ไปฌๅŠ ๅ…ฅไบ†ไธ€็ณปๅˆ—ๆ–นไพฟๆ‚จไฝฟ็”จ็š„ๅทฅๅ…ทใ€‚ไพ‹ๅฆ‚๏ผŒๆ‚จๅฏไปฅๅœจ่ฎญ็ปƒ่ฟ‡็จ‹ไธญๆŒ็ปญ็›‘ๆต‹ๅนถ่ฎฐๅฝ•ๆจกๅž‹็›ฎๅ‰็š„่ฟ่กŒ็Šถๆ€ๅ’Œ่กจ็Žฐ๏ผŒ่ฟ™ไบ›ๅŠŸ่ƒฝ้ƒฝๆ˜ฏ้€š่ฟ‡้’ฉๅญๅ‡ฝๆ•ฐๆฅๅฎž็Žฐ็š„ใ€‚ๆˆ‘ไปฌๆไพ›็š„`BasicHook` +็ฑป่ฎฉๆ‚จๅฏไปฅๅœจๆŒ‡ๅฎšๆ—ถ้—ดๆ‰ง่กŒๆ‚จ็š„้’ฉๅญๅ‡ฝๆ•ฐใ€‚ๅฆ‚ไธ‹ๆ–น็š„ไปฃ็ ๅ—ๆ‰€็คบ๏ผŒๆˆ‘ไปฌไบ‹ๅ…ˆไธบๆ‚จๅฎšไน‰ๅฅฝไบ†ไธ€ไบ›ๅฎž็”จ็š„้’ฉๅญๅ‡ฝๆ•ฐ๏ผŒๆ‚จ้œ€่ฆๅš็š„ๅฐฑๆ˜ฏๆ‰พๅˆฐ็ฌฆๅˆๆ‚จ้œ€ๆฑ‚็š„้’ฉๅญๅ‡ฝๆ•ฐใ€‚ๆ›ดๅคš่ฏฅ็ฑป็š„็›ธๅ…ณไฟกๆฏๅฏไปฅๅœจAPIไฟกๆฏไธญๆ‰พๅˆฐใ€‚ ```python hooks = [ @@ -70,7 +74,8 @@ hooks = [ ### ้’ฉๅญๅ‡ฝๆ•ฐ -ๅฆ‚ๆžœๆ‚จๆœ‰ไธชๆ€งๅŒ–้œ€ๆฑ‚๏ผŒๆ‚จๅฏไปฅ็ปงๆ‰ฟๆˆ‘ไปฌ็š„`BaseHook`็ฑปๅนถๆทปๅŠ ๆ‚จ็š„้’ฉๅญๅ‡ฝๆ•ฐ๏ผŒๆˆ–่€…็ปงๆ‰ฟๆˆ‘ไปฌ็š„`MetricHook`ๆฅ็ผ–ๅ†™ๆ‚จ้œ€่ฆ็š„ๅบฆ้‡ๆ ‡ๅ‡†ใ€‚่ฟ™ไบ›้’ฉๅญๅ‡ฝๆ•ฐๅฏไปฅๅœจ`Trainer`็”Ÿๅ‘ฝๅ‘จๆœŸ็š„12ไธชๆ—ถ้—ด็‚น่ขซๆ‰ง่กŒใ€‚ๆ›ดๅคš่ฏฅ็ฑป็š„็›ธๅ…ณไฟกๆฏๅฏไปฅๅœจAPIไฟกๆฏไธญๆ‰พๅˆฐใ€‚ +ๅฆ‚ๆžœๆ‚จๆœ‰ไธชๆ€งๅŒ–้œ€ๆฑ‚๏ผŒๆ‚จๅฏไปฅ็ปงๆ‰ฟๆˆ‘ไปฌ็š„`BaseHook`็ฑปๅนถๆทปๅŠ ๆ‚จ็š„้’ฉๅญๅ‡ฝๆ•ฐ๏ผŒๆˆ–่€…็ปงๆ‰ฟๆˆ‘ไปฌ็š„`MetricHook`ๆฅ็ผ–ๅ†™ๆ‚จ้œ€่ฆ็š„ๅบฆ้‡ๆ ‡ๅ‡†ใ€‚่ฟ™ไบ›้’ฉๅญๅ‡ฝๆ•ฐๅฏไปฅๅœจ`Trainer` +็”Ÿๅ‘ฝๅ‘จๆœŸ็š„12ไธชๆ—ถ้—ด็‚น่ขซๆ‰ง่กŒใ€‚ๆ›ดๅคš่ฏฅ็ฑป็š„็›ธๅ…ณไฟกๆฏๅฏไปฅๅœจAPIไฟกๆฏไธญๆ‰พๅˆฐใ€‚ ### ๅบฆ้‡ๆ ‡ๅ‡† diff --git a/examples/colossal_cifar_demo.ipynb b/examples/colossal_cifar_demo.ipynb index 2ad9022c9..221707bbb 100644 --- a/examples/colossal_cifar_demo.ipynb +++ b/examples/colossal_cifar_demo.ipynb @@ -1,370 +1,370 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "colossal_cifar_demo.ipynb", - "provenance": [] - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - }, - "accelerator": "GPU" + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "colossal_cifar_demo.ipynb", + "provenance": [] }, - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "uhrbvVEh2iJd" - }, - "source": [ - "# Train an image classifier\n" - ] + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "uhrbvVEh2iJd" + }, + "source": [ + "# Train an image classifier\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "vP7LvCpG23a2", + "outputId": "b37f7203-8a02-4736-c527-603f2bb34d7d" + }, + "source": [ + "!pip install ColossalAI deepspeed" + ], + "execution_count": null, + "outputs": [ { - "cell_type": "code", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "vP7LvCpG23a2", - "outputId": "b37f7203-8a02-4736-c527-603f2bb34d7d" - }, - "source": [ - "!pip install ColossalAI deepspeed" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Requirement already satisfied: ColossalAI in /usr/local/lib/python3.7/dist-packages (0.1)\n", - "Requirement already satisfied: deepspeed in /usr/local/lib/python3.7/dist-packages (0.5.4)\n", - "Requirement already satisfied: packaging in /usr/local/lib/python3.7/dist-packages (from deepspeed) (21.0)\n", - "Requirement already satisfied: triton in /usr/local/lib/python3.7/dist-packages (from deepspeed) (1.1.1)\n", - "Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (from deepspeed) (4.62.3)\n", - "Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from deepspeed) (1.19.5)\n", - "Requirement already satisfied: tensorboardX==1.8 in /usr/local/lib/python3.7/dist-packages (from deepspeed) (1.8)\n", - "Requirement already satisfied: ninja in /usr/local/lib/python3.7/dist-packages (from deepspeed) (1.10.2.2)\n", - "Requirement already satisfied: torch in /usr/local/lib/python3.7/dist-packages (from deepspeed) (1.9.0+cu111)\n", - "Requirement already satisfied: psutil in /usr/local/lib/python3.7/dist-packages (from deepspeed) (5.4.8)\n", - "Requirement already satisfied: protobuf>=3.2.0 in /usr/local/lib/python3.7/dist-packages (from tensorboardX==1.8->deepspeed) (3.17.3)\n", - "Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from tensorboardX==1.8->deepspeed) (1.15.0)\n", - "Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging->deepspeed) (2.4.7)\n", - "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch->deepspeed) (3.7.4.3)\n", - "Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from triton->deepspeed) (3.3.0)\n" - ] - } - ] - }, - { - "cell_type": "code", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "UVKEurtS4SFS", - "outputId": "99fb6050-5da7-4f27-b4eb-9b3ccf830efb" - }, - "source": [ - "import colossalai\n", - "from colossalai.engine import Engine, NoPipelineSchedule\n", - "from colossalai.trainer import Trainer\n", - "from colossalai.context import Config\n", - "import torch" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Please install apex to use FP16 Optimizer\n", - "Apex should be installed to use the FP16 optimizer\n", - "apex is required for mixed precision training\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "PpFfhNBD7NSn" - }, - "source": [ - "First, we should initialize distributed environment. Though we just use single GPU in this example, we still need initialize distributed environment for compatibility. We just consider the simplest case here, so we just set the number of parallel processes to 1." - ] - }, - { - "cell_type": "code", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "8yF7Lc-K7NAS", - "outputId": "01312349-a8b0-4de4-9103-7d1b48e6cc36" - }, - "source": [ - "parallel_cfg = Config(dict(parallel=dict(\n", - " data=dict(size=1),\n", - " pipeline=dict(size=1),\n", - " tensor=dict(size=1, mode=None),\n", - ")))\n", - "colossalai.init_dist(config=parallel_cfg,\n", - " local_rank=0,\n", - " world_size=1,\n", - " host='127.0.0.1',\n", - " port=8888,\n", - " backend='nccl')" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stderr", - "text": [ - "colossalai - torch.distributed.distributed_c10d - 2021-10-15 03:27:51,596 INFO: Added key: store_based_barrier_key:1 to store for rank: 0\n", - "colossalai - torch.distributed.distributed_c10d - 2021-10-15 03:27:51,598 INFO: Rank 0: Completed store-based barrier for 1 nodes.\n", - "colossalai - torch.distributed.distributed_c10d - 2021-10-15 03:27:51,602 INFO: Added key: store_based_barrier_key:2 to store for rank: 0\n", - "colossalai - torch.distributed.distributed_c10d - 2021-10-15 03:27:51,605 INFO: Rank 0: Completed store-based barrier for 1 nodes.\n", - "colossalai - torch.distributed.distributed_c10d - 2021-10-15 03:27:51,608 INFO: Added key: store_based_barrier_key:3 to store for rank: 0\n", - "colossalai - torch.distributed.distributed_c10d - 2021-10-15 03:27:51,610 INFO: Rank 0: Completed store-based barrier for 1 nodes.\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "process rank 0 is bound to device 0\n", - "initialized seed on rank 0, numpy: 1024, python random: 1024, ParallelMode.DATA: 1024, ParallelMode.TENSOR: 1124,the default parallel seed is ParallelMode.DATA.\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ppjmMxc_81TK" - }, - "source": [ - "Load and normalize the CIFAR10 training and test datasets using `colossalai.nn.data`. Note that we have wrapped `torchvision.transforms`, so that we can simply use the config dict to use them." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "ZyGhyD47-dUY", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "98bbf2d1-a1c4-4bb4-b6df-600777b1e8f5" - }, - "source": [ - "transform_cfg = [\n", - " dict(type='ToTensor'),\n", - " dict(type='Normalize',\n", - " mean=[0.4914, 0.4822, 0.4465],\n", - " std=[0.2023, 0.1994, 0.2010]),\n", - "]\n", - "\n", - "batch_size = 128\n", - "\n", - "trainset = colossalai.nn.data.CIFAR10Dataset(transform_cfg, root='./data', train=True)\n", - "trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)\n", - "\n", - "testset = colossalai.nn.data.CIFAR10Dataset(transform_cfg, root='./data', train=False)\n", - "testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Files already downloaded and verified\n", - "Files already downloaded and verified\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "NvPbfLLR9NzC" - }, - "source": [ - "We just define a simple Convolutional Neural Network here." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "cQ_y7lBG09LS" - }, - "source": [ - "import torch.nn as nn\n", - "import torch.nn.functional as F\n", - "\n", - "\n", - "class Net(nn.Module):\n", - " def __init__(self):\n", - " super().__init__()\n", - " self.conv1 = nn.Conv2d(3, 6, 5)\n", - " self.pool = nn.MaxPool2d(2, 2)\n", - " self.conv2 = nn.Conv2d(6, 16, 5)\n", - " self.fc1 = nn.Linear(16 * 5 * 5, 120)\n", - " self.fc2 = nn.Linear(120, 84)\n", - " self.fc3 = nn.Linear(84, 10)\n", - "\n", - " def forward(self, x):\n", - " x = self.pool(F.relu(self.conv1(x)))\n", - " x = self.pool(F.relu(self.conv2(x)))\n", - " x = torch.flatten(x, 1) # flatten all dimensions except batch\n", - " x = F.relu(self.fc1(x))\n", - " x = F.relu(self.fc2(x))\n", - " x = self.fc3(x)\n", - " return x\n", - "\n", - "\n", - "model = Net().cuda()" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tgsszAmM9dYZ" - }, - "source": [ - "Define a Loss function and optimizer. And then we use them to initialize `Engine` and `Trainer`. We provide various training / evaluating hooks. In this case, we just use the simplest hooks which can compute and print loss and accuracy." - ] - }, - { - "cell_type": "code", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "YtaDoCax1BCf", - "outputId": "b33b1641-03d8-4597-c8c2-1a4c1d61e9b0" - }, - "source": [ - "import torch.optim as optim\n", - "\n", - "criterion = nn.CrossEntropyLoss()\n", - "optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)\n", - "schedule = NoPipelineSchedule()\n", - "engine = Engine(\n", - " model=model,\n", - " criterion=criterion,\n", - " optimizer=optimizer,\n", - " lr_scheduler=None,\n", - " schedule=schedule\n", - " )\n", - "trainer = Trainer(engine=engine,\n", - " hooks_cfg=[dict(type='LossHook'), dict(type='LogMetricByEpochHook'), dict(type='AccuracyHook')],\n", - " verbose=True)" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stderr", - "text": [ - "colossalai - rank_0 - 2021-10-15 03:27:56,018 WARNING: No gradient handler is set up, please make sure you do not need to all-reduce the gradients after a training step.\n", - "colossalai - rank_0 - 2021-10-15 03:27:56,024 INFO: build LogMetricByEpochHook for train, priority = 1\n", - "colossalai - rank_0 - 2021-10-15 03:27:56,026 INFO: build LossHook for train, priority = 10\n", - "colossalai - rank_0 - 2021-10-15 03:27:56,029 INFO: build AccuracyHook for train, priority = 10\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "_JR2TuvH99Ik" - }, - "source": [ - "Then we set training configs. We train our model for 10 epochs and it will be evaluated every 1 epoch. Set `display_progress` to `True` to display the training / evaluating progress bar." - ] - }, - { - "cell_type": "code", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "w-J3IP-J1sfx", - "outputId": "bdb76939-04f1-4124-ce5e-3af44c0d902c" - }, - "source": [ - "num_epochs = 10\n", - "test_interval = 1\n", - "trainer.fit(\n", - " train_dataloader=trainloader,\n", - " test_dataloader=testloader,\n", - " max_epochs=num_epochs,\n", - " display_progress=True,\n", - " test_interval=test_interval\n", - " )" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stderr", - "text": [ - "[Epoch 0 train]: 0%| | 0/391 [00:00=3.2.0 in /usr/local/lib/python3.7/dist-packages (from tensorboardX==1.8->deepspeed) (3.17.3)\n", + "Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from tensorboardX==1.8->deepspeed) (1.15.0)\n", + "Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging->deepspeed) (2.4.7)\n", + "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch->deepspeed) (3.7.4.3)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from triton->deepspeed) (3.3.0)\n" + ] } - ] + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "UVKEurtS4SFS", + "outputId": "99fb6050-5da7-4f27-b4eb-9b3ccf830efb" + }, + "source": [ + "import colossalai\n", + "from colossalai.engine import Engine, NoPipelineSchedule\n", + "from colossalai.trainer import Trainer\n", + "from colossalai.context import Config\n", + "import torch" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Please install apex to use FP16 Optimizer\n", + "Apex should be installed to use the FP16 optimizer\n", + "apex is required for mixed precision training\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "PpFfhNBD7NSn" + }, + "source": [ + "First, we should initialize distributed environment. Though we just use single GPU in this example, we still need initialize distributed environment for compatibility. We just consider the simplest case here, so we just set the number of parallel processes to 1." + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "8yF7Lc-K7NAS", + "outputId": "01312349-a8b0-4de4-9103-7d1b48e6cc36" + }, + "source": [ + "parallel_cfg = Config(dict(parallel=dict(\n", + " data=dict(size=1),\n", + " pipeline=dict(size=1),\n", + " tensor=dict(size=1, mode=None),\n", + ")))\n", + "colossalai.init_dist(config=parallel_cfg,\n", + " local_rank=0,\n", + " world_size=1,\n", + " host='127.0.0.1',\n", + " port=8888,\n", + " backend='nccl')" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "colossalai - torch.distributed.distributed_c10d - 2021-10-15 03:27:51,596 INFO: Added key: store_based_barrier_key:1 to store for rank: 0\n", + "colossalai - torch.distributed.distributed_c10d - 2021-10-15 03:27:51,598 INFO: Rank 0: Completed store-based barrier for 1 nodes.\n", + "colossalai - torch.distributed.distributed_c10d - 2021-10-15 03:27:51,602 INFO: Added key: store_based_barrier_key:2 to store for rank: 0\n", + "colossalai - torch.distributed.distributed_c10d - 2021-10-15 03:27:51,605 INFO: Rank 0: Completed store-based barrier for 1 nodes.\n", + "colossalai - torch.distributed.distributed_c10d - 2021-10-15 03:27:51,608 INFO: Added key: store_based_barrier_key:3 to store for rank: 0\n", + "colossalai - torch.distributed.distributed_c10d - 2021-10-15 03:27:51,610 INFO: Rank 0: Completed store-based barrier for 1 nodes.\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "process rank 0 is bound to device 0\n", + "initialized seed on rank 0, numpy: 1024, python random: 1024, ParallelMode.DATA: 1024, ParallelMode.TENSOR: 1124,the default parallel seed is ParallelMode.DATA.\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ppjmMxc_81TK" + }, + "source": [ + "Load and normalize the CIFAR10 training and test datasets using `colossalai.nn.data`. Note that we have wrapped `torchvision.transforms`, so that we can simply use the config dict to use them." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "ZyGhyD47-dUY", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "98bbf2d1-a1c4-4bb4-b6df-600777b1e8f5" + }, + "source": [ + "transform_cfg = [\n", + " dict(type='ToTensor'),\n", + " dict(type='Normalize',\n", + " mean=[0.4914, 0.4822, 0.4465],\n", + " std=[0.2023, 0.1994, 0.2010]),\n", + "]\n", + "\n", + "batch_size = 128\n", + "\n", + "trainset = colossalai.nn.data.CIFAR10Dataset(transform_cfg, root='./data', train=True)\n", + "trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)\n", + "\n", + "testset = colossalai.nn.data.CIFAR10Dataset(transform_cfg, root='./data', train=False)\n", + "testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Files already downloaded and verified\n", + "Files already downloaded and verified\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NvPbfLLR9NzC" + }, + "source": [ + "We just define a simple Convolutional Neural Network here." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "cQ_y7lBG09LS" + }, + "source": [ + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "\n", + "class Net(nn.Module):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.conv1 = nn.Conv2d(3, 6, 5)\n", + " self.pool = nn.MaxPool2d(2, 2)\n", + " self.conv2 = nn.Conv2d(6, 16, 5)\n", + " self.fc1 = nn.Linear(16 * 5 * 5, 120)\n", + " self.fc2 = nn.Linear(120, 84)\n", + " self.fc3 = nn.Linear(84, 10)\n", + "\n", + " def forward(self, x):\n", + " x = self.pool(F.relu(self.conv1(x)))\n", + " x = self.pool(F.relu(self.conv2(x)))\n", + " x = torch.flatten(x, 1) # flatten all dimensions except batch\n", + " x = F.relu(self.fc1(x))\n", + " x = F.relu(self.fc2(x))\n", + " x = self.fc3(x)\n", + " return x\n", + "\n", + "\n", + "model = Net().cuda()" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "tgsszAmM9dYZ" + }, + "source": [ + "Define a Loss function and optimizer. And then we use them to initialize `Engine` and `Trainer`. We provide various training / evaluating hooks. In this case, we just use the simplest hooks which can compute and print loss and accuracy." + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "YtaDoCax1BCf", + "outputId": "b33b1641-03d8-4597-c8c2-1a4c1d61e9b0" + }, + "source": [ + "import torch.optim as optim\n", + "\n", + "criterion = nn.CrossEntropyLoss()\n", + "optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)\n", + "schedule = NoPipelineSchedule()\n", + "engine = Engine(\n", + " model=model,\n", + " criterion=criterion,\n", + " optimizer=optimizer,\n", + " lr_scheduler=None,\n", + " schedule=schedule\n", + " )\n", + "trainer = Trainer(engine=engine,\n", + " hooks_cfg=[dict(type='LossHook'), dict(type='LogMetricByEpochHook'), dict(type='AccuracyHook')],\n", + " verbose=True)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "colossalai - rank_0 - 2021-10-15 03:27:56,018 WARNING: No gradient handler is set up, please make sure you do not need to all-reduce the gradients after a training step.\n", + "colossalai - rank_0 - 2021-10-15 03:27:56,024 INFO: build LogMetricByEpochHook for train, priority = 1\n", + "colossalai - rank_0 - 2021-10-15 03:27:56,026 INFO: build LossHook for train, priority = 10\n", + "colossalai - rank_0 - 2021-10-15 03:27:56,029 INFO: build AccuracyHook for train, priority = 10\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_JR2TuvH99Ik" + }, + "source": [ + "Then we set training configs. We train our model for 10 epochs and it will be evaluated every 1 epoch. Set `display_progress` to `True` to display the training / evaluating progress bar." + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "w-J3IP-J1sfx", + "outputId": "bdb76939-04f1-4124-ce5e-3af44c0d902c" + }, + "source": [ + "num_epochs = 10\n", + "test_interval = 1\n", + "trainer.fit(\n", + " train_dataloader=trainloader,\n", + " test_dataloader=testloader,\n", + " max_epochs=num_epochs,\n", + " display_progress=True,\n", + " test_interval=test_interval\n", + " )" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "[Epoch 0 train]: 0%| | 0/391 [00:00=0.9 numpy tqdm psutil -tensorboardX +tensorboard packaging \ No newline at end of file diff --git a/setup.py b/setup.py index d71876bb9..8541b0a6c 100644 --- a/setup.py +++ b/setup.py @@ -121,7 +121,7 @@ if "--cuda_ext" in sys.argv: install_requires = fetch_requirements('requirements/requirements.txt') setup( - name='colossal-ai', + name='colossalai', version='0.0.1-beta', packages=find_packages(exclude=('csrc', 'tests', diff --git a/tests/test_data_pipeline_tensor_parallel/configs/vit_2d.py b/tests/test_data_pipeline_tensor_parallel/configs/vit_2d.py index 907605317..c97ed1804 100644 --- a/tests/test_data_pipeline_tensor_parallel/configs/vit_2d.py +++ b/tests/test_data_pipeline_tensor_parallel/configs/vit_2d.py @@ -27,8 +27,6 @@ train_data = dict( dataloader=dict( batch_size=BATCH_SIZE, pin_memory=True, - # num_workers=1, - # shuffle=True, ) ) @@ -63,14 +61,6 @@ loss = dict( type='CrossEntropyLoss2D', ) -# model = dict( -# type='VanillaResNet', -# block_type='ResNetBasicBlock', -# layers=[2, 2, 2, 2], -# num_cls=10 -# ) - - model = dict( type='VisionTransformerFromConfig', tensor_splitting_cfg=dict( @@ -135,25 +125,26 @@ parallel = dict( fp16 = dict( mode=AMP_TYPE.PARALLEL, - initial_scale=2 ** 8 ) -# fp16 = dict( -# mode=None, -# ) - -schedule = dict( - num_microbatches=2 -) -lr_scheduler = dict( - type='LinearWarmupLR', - warmup_epochs=5 +engine = dict( + schedule=dict( + num_microbatches=2 + ) ) +hooks = [ + dict( + type='LRSchedulerHook', + by_epoch=True, + lr_scheduler_cfg=dict( + type='LinearWarmupLR', + warmup_steps=5 + ) + ), +] num_epochs = 60 logging = dict( root_path='test_vit_2d_log' ) - -seed = 100 diff --git a/tests/test_data_pipeline_tensor_parallel/configs/vit_2p5d.py b/tests/test_data_pipeline_tensor_parallel/configs/vit_2p5d.py index d41ecea89..fd9c89eb4 100644 --- a/tests/test_data_pipeline_tensor_parallel/configs/vit_2p5d.py +++ b/tests/test_data_pipeline_tensor_parallel/configs/vit_2p5d.py @@ -124,14 +124,21 @@ parallel = dict( tensor=dict(size=4, depth=1, mode='2.5d'), ) -lr_scheduler = dict( - type='LinearWarmupLR', - warmup_epochs=5 -) +hooks = [ + dict( + type='LRSchedulerHook', + by_epoch=True, + lr_scheduler_cfg=dict( + type='LinearWarmupLR', + warmup_steps=5 + ) + ), +] +engine = dict( schedule = dict( num_microbatches=2 ) +) num_epochs = 60 -num_microbatches = 1 diff --git a/tests/test_data_pipeline_tensor_parallel/test_vit_2d/test_vit_2d.py b/tests/test_data_pipeline_tensor_parallel/test_vit_2d/test_vit_2d.py index 9ffd0a1ec..b68a58cea 100644 --- a/tests/test_data_pipeline_tensor_parallel/test_vit_2d/test_vit_2d.py +++ b/tests/test_data_pipeline_tensor_parallel/test_vit_2d/test_vit_2d.py @@ -9,21 +9,22 @@ import torch.autograd import colossalai from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc -from colossalai.engine import Engine from colossalai.logging import get_global_dist_logger from colossalai.nn.layer._parallel_utilities import _gather CONFIG_PATH = Path(__file__).parent.parent.joinpath('configs/vit_2d.py') -def eval(engine): +def eval(engine, test_dataloader): engine.eval() accumulated_loss = 0 correct_sum = 0 total_sum = 0 + num_steps = len(test_dataloader) + data_iter = iter(test_dataloader) - for i in range(engine.schedule.num_steps): - output, label, loss = engine.step() + for i in range(num_steps): + output, label, loss = engine.step(data_iter) if gpc.is_last_rank(ParallelMode.PIPELINE): # loss = sum(loss) @@ -43,20 +44,22 @@ def eval(engine): correct = torch.sum(label == output) correct_sum += correct total_sum += label.size(0) - avg_loss = accumulated_loss / engine.schedule.num_steps + avg_loss = accumulated_loss / num_steps return correct_sum, total_sum, avg_loss -def train(engine): +def train(engine, train_dataloader): engine.train() accumulated_loss = 0 + num_steps = len(train_dataloader) + data_iter = iter(train_dataloader) - for i in range(engine.schedule.num_steps): - output, label, loss = engine.step() + for i in range(num_steps): + output, label, loss = engine.step(data_iter) if gpc.is_last_rank(ParallelMode.PIPELINE): accumulated_loss += loss.detach().cpu().numpy() - avg_loss = accumulated_loss / engine.schedule.num_steps + avg_loss = accumulated_loss / num_steps return avg_loss @@ -64,25 +67,16 @@ def train(engine): @pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus") def test_2d_parallel_vision_transformer(): # init dist - model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize( - CONFIG_PATH) + engine, train_dataloader, test_dataloader = colossalai.initialize(CONFIG_PATH) logger = get_global_dist_logger() - engine = Engine(model=model, - train_dataloader=train_dataloader, - test_dataloader=test_dataloader, - criterion=criterion, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - schedule=schedule) - for epoch in range(gpc.config.num_epochs): - train_loss = train(engine) + train_loss = train(engine, train_dataloader) if gpc.is_last_rank(ParallelMode.PIPELINE): logger.info(f'epoch {epoch} - train loss: {train_loss}') if epoch % 2 == 0: - correct_sum, total_sum, eval_loss = eval(engine) + correct_sum, total_sum, eval_loss = eval(engine, test_dataloader) if gpc.is_last_rank(ParallelMode.PIPELINE): logger.info( f'epoch {epoch} - eval loss: {eval_loss}, total: {total_sum}, ' diff --git a/tests/test_data_pipeline_tensor_parallel/test_vit_2p5d/test_vit_2p5d.py b/tests/test_data_pipeline_tensor_parallel/test_vit_2p5d/test_vit_2p5d.py index 33d56360a..70857f1e8 100644 --- a/tests/test_data_pipeline_tensor_parallel/test_vit_2p5d/test_vit_2p5d.py +++ b/tests/test_data_pipeline_tensor_parallel/test_vit_2p5d/test_vit_2p5d.py @@ -6,20 +6,22 @@ import torch.autograd import colossalai from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc -from colossalai.engine import Engine from colossalai.logging import get_global_dist_logger from colossalai.nn.layer._parallel_utilities import _gather CONFIG_PATH = Path(__file__).parent.parent.joinpath('configs/vit_2p5d.py') -def eval(engine): + +def eval(engine, test_dataloader): engine.eval() accumulated_loss = 0 correct_sum = 0 total_sum = 0 + num_steps = len(test_dataloader) + data_iter = iter(test_dataloader) - for i in range(engine.schedule.num_steps): - output, label, loss = engine.step() + for i in range(num_steps): + output, label, loss = engine.step(data_iter) if gpc.is_last_rank(ParallelMode.PIPELINE): accumulated_loss += loss.detach().cpu().numpy() @@ -43,21 +45,23 @@ def eval(engine): correct = torch.sum(label == output) correct_sum += correct total_sum += label.size(0) - avg_loss = accumulated_loss / engine.schedule.num_steps + avg_loss = accumulated_loss / num_steps return correct_sum, total_sum, avg_loss -def train(engine): +def train(engine, train_dataloader): engine.train() accumulated_loss = 0 + num_steps = len(train_dataloader) + data_iter = iter(train_dataloader) + + for i in range(num_steps): + output, label, loss = engine.step(data_iter) - for i in range(engine.schedule.num_steps): - output, label, loss = engine.step() - if gpc.is_last_rank(ParallelMode.PIPELINE): accumulated_loss += loss.detach().cpu().numpy() - avg_loss = accumulated_loss / engine.schedule.num_steps + avg_loss = accumulated_loss / num_steps return avg_loss @@ -65,25 +69,16 @@ def train(engine): @pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus") def test_2p5d_parallel_vision_transformer(): # init dist - model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize( - CONFIG_PATH) + engine, train_dataloader, test_dataloader = colossalai.initialize(CONFIG_PATH) logger = get_global_dist_logger() - engine = Engine(model=model, - train_dataloader=train_dataloader, - test_dataloader=test_dataloader, - criterion=criterion, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - schedule=schedule) - for epoch in range(gpc.config.num_epochs): - train_loss = train(engine) + train_loss = train(engine, train_dataloader) if gpc.is_last_rank(ParallelMode.PIPELINE): logger.info(f'epoch {epoch} - train loss: {train_loss}') if epoch % 2 == 0: - correct_sum, total_sum, eval_loss = eval(engine) + correct_sum, total_sum, eval_loss = eval(engine, test_dataloader) if gpc.is_last_rank(ParallelMode.PIPELINE): logger.info( f'epoch {epoch} - eval loss: {eval_loss}, total: {total_sum}, ' @@ -91,4 +86,4 @@ def test_2p5d_parallel_vision_transformer(): if __name__ == '__main__': - test_2p5d_parallel_vision_transformer() \ No newline at end of file + test_2p5d_parallel_vision_transformer() diff --git a/tests/test_engine/configs/non_pipeline_resnet.py b/tests/test_engine/configs/non_pipeline_resnet.py index de78154ec..19f2d61d2 100644 --- a/tests/test_engine/configs/non_pipeline_resnet.py +++ b/tests/test_engine/configs/non_pipeline_resnet.py @@ -38,5 +38,3 @@ optimizer = dict(type='Adam', lr=0.001) loss = dict(type='CrossEntropyLoss') -# set_device_func = lambda global_rank, world_size: global_rank % 4 -seed = 1024 diff --git a/tests/test_engine/configs/non_pipeline_resnet_apex_amp.py b/tests/test_engine/configs/non_pipeline_resnet_apex_amp.py index b6300b8c4..f845d9842 100644 --- a/tests/test_engine/configs/non_pipeline_resnet_apex_amp.py +++ b/tests/test_engine/configs/non_pipeline_resnet_apex_amp.py @@ -40,6 +40,3 @@ optimizer = dict(type='Adam', lr=0.001) loss = dict(type='CrossEntropyLoss') fp16 = dict(mode=AMP_TYPE.APEX) - -# set_device_func = lambda global_rank, world_size: global_rank % 4 -seed = 1024 diff --git a/tests/test_engine/configs/non_pipeline_resnet_torch_amp.py b/tests/test_engine/configs/non_pipeline_resnet_torch_amp.py index 87fd68554..ab4517e92 100644 --- a/tests/test_engine/configs/non_pipeline_resnet_torch_amp.py +++ b/tests/test_engine/configs/non_pipeline_resnet_torch_amp.py @@ -40,6 +40,3 @@ optimizer = dict(type='Adam', lr=0.001) loss = dict(type='CrossEntropyLoss') fp16 = dict(mode=AMP_TYPE.TORCH) - -# set_device_func = lambda global_rank, world_size: global_rank % 4 -seed = 1024 diff --git a/tests/test_engine/configs/pipeline_vanilla_resnet.py b/tests/test_engine/configs/pipeline_vanilla_resnet.py index 9820d3b82..a47f40613 100644 --- a/tests/test_engine/configs/pipeline_vanilla_resnet.py +++ b/tests/test_engine/configs/pipeline_vanilla_resnet.py @@ -38,11 +38,9 @@ parallel = dict( tensor=dict(size=1, mode=None) ) -schedule = dict( - num_microbatches=4 +engine = dict( + schedule=dict( + num_microbatches=4 + ) ) -num_pipeling_batches = 2 -seed = 1024 -lr_scheduler = dict(type='LinearWarmupLR', warmup_steps=5) - num_epochs = 10 diff --git a/tests/test_engine/test_non_pipeline_engine/test_engine_apex_amp.py b/tests/test_engine/test_non_pipeline_engine/test_engine_apex_amp.py index fe6b4010b..98c2b8072 100644 --- a/tests/test_engine/test_non_pipeline_engine/test_engine_apex_amp.py +++ b/tests/test_engine/test_non_pipeline_engine/test_engine_apex_amp.py @@ -8,7 +8,6 @@ import torch from colossalai import initialize from colossalai.core import global_context as gpc -from colossalai.engine import Engine from colossalai.logging import get_global_dist_logger from colossalai.utils import report_memory_usage @@ -24,20 +23,13 @@ NO_PIPE_CONFIG_PATH = osp.join(DIR_PATH, '../configs/non_pipeline_resnet_apex_am def run_no_pipeline(config): - model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = initialize(config) + engine, train_dataloader, test_dataloader = initialize(config) logger = get_global_dist_logger() rank = torch.distributed.get_rank() - engine = Engine(model=model, - train_dataloader=train_dataloader, - criterion=criterion, - optimizer=optimizer, - schedule=schedule) engine.train() - logger.info('lr = %g' % engine.get_lr()) - output, label, loss = engine.step() + output, label, loss = engine.step(iter(train_dataloader)) logger.info('Rank {} returns: {}'.format(rank, loss.item())) - logger.info('lr = %g' % engine.get_lr()) gpc.destroy() logger.info('Test engine finished') diff --git a/tests/test_engine/test_non_pipeline_engine/test_engine_no_amp.py b/tests/test_engine/test_non_pipeline_engine/test_engine_no_amp.py index 865f2b04e..effb65e02 100644 --- a/tests/test_engine/test_non_pipeline_engine/test_engine_no_amp.py +++ b/tests/test_engine/test_non_pipeline_engine/test_engine_no_amp.py @@ -8,7 +8,6 @@ import torch from colossalai import initialize from colossalai.core import global_context as gpc -from colossalai.engine import Engine from colossalai.logging import get_global_dist_logger from colossalai.utils import report_memory_usage @@ -26,21 +25,14 @@ NO_PIPE_CONFIG_PATH = osp.join(DIR_PATH, '../configs/non_pipeline_resnet.py') def test_no_pipeline(config): print('Test no pipeline engine start') - model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = initialize(config) + engine, train_dataloader, test_dataloader = initialize(config) logger = get_global_dist_logger() rank = torch.distributed.get_rank() - engine = Engine(model=model, - train_dataloader=train_dataloader, - criterion=criterion, - optimizer=optimizer, - schedule=schedule) engine.train() - logger.info('lr = %g' % engine.get_lr()) - output, label, loss = engine.step() + output, label, loss = engine.step(iter(train_dataloader)) logger.info('Rank {} returns: {}'.format(rank, loss.item())) - logger.info('lr = %g' % engine.get_lr()) gpc.destroy() logger.info('Test engine finished') diff --git a/tests/test_engine/test_non_pipeline_engine/test_engine_torch_amp.py b/tests/test_engine/test_non_pipeline_engine/test_engine_torch_amp.py index 83c6927f3..a4c496a7d 100644 --- a/tests/test_engine/test_non_pipeline_engine/test_engine_torch_amp.py +++ b/tests/test_engine/test_non_pipeline_engine/test_engine_torch_amp.py @@ -8,7 +8,6 @@ import torch from colossalai import initialize from colossalai.core import global_context as gpc -from colossalai.engine import Engine from colossalai.logging import get_global_dist_logger from colossalai.utils import report_memory_usage @@ -26,21 +25,13 @@ NO_PIPE_CONFIG_PATH = osp.join(DIR_PATH, '../configs/non_pipeline_resnet_torch_a def test_no_pipeline(config): print('Test no pipeline engine start') - model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = initialize(config) + engine, train_dataloader, test_dataloader = initialize(config) logger = get_global_dist_logger() - rank = torch.distributed.get_rank() - engine = Engine(model=model, - train_dataloader=train_dataloader, - criterion=criterion, - optimizer=optimizer, - schedule=schedule) engine.train() - logger.info('lr = %g' % engine.get_lr()) - output, label, loss = engine.step() + output, label, loss = engine.step(iter(train_dataloader)) logger.info('Rank {} returns: {}'.format(rank, loss.item())) - logger.info('lr = %g' % engine.get_lr()) gpc.destroy() logger.info('Test engine finished') diff --git a/tests/test_engine/test_pipeline/test_schedule.py b/tests/test_engine/test_pipeline/test_schedule.py index 32fcaafc1..9125fb3ee 100644 --- a/tests/test_engine/test_pipeline/test_schedule.py +++ b/tests/test_engine/test_pipeline/test_schedule.py @@ -5,6 +5,7 @@ import os.path as osp import pytest +from colossalai.context import ParallelMode from colossalai.core import global_context as gpc from colossalai.initialize import initialize from colossalai.logging import get_global_dist_logger @@ -22,13 +23,25 @@ CONFIG_PATH = osp.join(DIR_PATH, '../configs/pipeline_vanilla_resnet.py') @pytest.mark.skip("This test should be invoked using the test.sh provided") @pytest.mark.dist def test_schedule(): - model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = initialize(CONFIG_PATH) + engine, train_dataloader, test_dataloader = initialize(CONFIG_PATH) logger = get_global_dist_logger() - schedule.zero_grad() - output, label, losses = schedule.forward_backward_step(forward_only=False) - schedule.step() - logger.info('losses: {}'.format([loss.item() for loss in losses])) + model = engine.model + optimizer = engine.optimizer + criterion = engine.criterion + schedule = engine._schedule + + output, label, loss = schedule.forward_backward_step( + data_iter=iter(train_dataloader), + model=model, + optimizer=optimizer, + criterion=criterion, + forward_only=False + ) + schedule.optimizer_step(model, optimizer) + + if gpc.is_last_rank(ParallelMode.PIPELINE): + logger.info('losses: {}'.format(loss)) gpc.destroy() logger.info('training finished') diff --git a/tests/test_engine/test_pipeline_engine/test_engine.py b/tests/test_engine/test_pipeline_engine/test_engine.py index 7ed0b0a3d..9d6c9f59f 100644 --- a/tests/test_engine/test_pipeline_engine/test_engine.py +++ b/tests/test_engine/test_pipeline_engine/test_engine.py @@ -9,7 +9,6 @@ import torch from colossalai import initialize from colossalai.context import ParallelMode from colossalai.core import global_context as gpc -from colossalai.engine import Engine from colossalai.logging import get_global_dist_logger NUM_BATCH = 128 @@ -23,22 +22,14 @@ PIPE_CONFIG_PATH = osp.join(DIR_PATH, '../configs/pipeline_vanilla_resnet.py') def run_pipeline(config): - model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = initialize(config) + engine, train_dataloader, test_dataloader = initialize(config) logger = get_global_dist_logger() rank = torch.distributed.get_rank() - engine = Engine(model=model, - train_dataloader=train_dataloader, - criterion=criterion, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - schedule=schedule) engine.train() - logger.info('lr = %g' % engine.get_lr()) - outputs, labels, loss = engine.step() + outputs, labels, loss = engine.step(iter(train_dataloader)) if gpc.is_last_rank(ParallelMode.PIPELINE): logger.info('losses: {}'.format(rank, loss.item())) - logger.info('lr = %g' % engine.get_lr()) gpc.destroy() logger.info('Test engine pipeline finished') diff --git a/tests/test_fp16_optimizer/configs/vit_2d.py b/tests/test_fp16_optimizer/configs/vit_2d.py index bcef5e2d4..6283dea9b 100644 --- a/tests/test_fp16_optimizer/configs/vit_2d.py +++ b/tests/test_fp16_optimizer/configs/vit_2d.py @@ -132,9 +132,12 @@ fp16 = dict( initial_scale=2 ** 4 ) +num_epochs = 60 + + lr_scheduler = dict( type='LinearWarmupLR', - warmup_epochs=5 + warmup_steps=5, + total_steps=num_epochs ) -num_epochs = 60 diff --git a/tests/test_fp16_optimizer/test_vit_2d/test_vit_2d.py b/tests/test_fp16_optimizer/test_vit_2d/test_vit_2d.py index a02ede90c..45c36f384 100644 --- a/tests/test_fp16_optimizer/test_vit_2d/test_vit_2d.py +++ b/tests/test_fp16_optimizer/test_vit_2d/test_vit_2d.py @@ -7,23 +7,25 @@ import pytest import torch.autograd import colossalai +from colossalai.builder import build_lr_scheduler from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc -from colossalai.engine import Engine from colossalai.logging import get_global_dist_logger from colossalai.nn.layer._parallel_utilities import _gather CONFIG_PATH = Path(__file__).parent.parent.joinpath('configs/vit_2d.py') -def eval(engine): +def eval(engine, test_dataloader): engine.eval() accumulated_loss = 0 correct_sum = 0 total_sum = 0 + num_steps = len(test_dataloader) + data_iter = iter(test_dataloader) - for i in range(engine.schedule.num_steps): - output, label, loss = engine.step() + for i in range(num_steps): + output, label, loss = engine.step(data_iter) accumulated_loss += loss.detach().cpu().numpy() output = _gather( @@ -40,18 +42,21 @@ def eval(engine): correct = torch.sum(label[0] == output) correct_sum += correct total_sum += label[0].size(0) - avg_loss = accumulated_loss / engine.schedule.num_steps + avg_loss = accumulated_loss / num_steps return correct_sum, total_sum, avg_loss -def train(engine): +def train(engine, train_dataloader, lr_scheduler): engine.train() accumulated_loss = 0 + num_steps = len(train_dataloader) + data_iter = iter(train_dataloader) - for i in range(engine.schedule.num_steps): - output, label, loss = engine.step() + for i in range(num_steps): + output, label, loss = engine.step(data_iter) accumulated_loss += loss.squeeze(0).detach().cpu().numpy() - avg_loss = accumulated_loss / engine.schedule.num_steps + avg_loss = accumulated_loss / num_steps + lr_scheduler.step() return avg_loss @@ -59,26 +64,18 @@ def train(engine): @pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus") def test_2d_parallel_vision_transformer(): # init dist - model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize( - CONFIG_PATH) + engine, train_dataloader, test_dataloader = colossalai.initialize(CONFIG_PATH) + lr_scheduler = build_lr_scheduler(gpc.config.lr_scheduler, engine.optimizer) logger = get_global_dist_logger() - engine = Engine(model=model, - train_dataloader=train_dataloader, - test_dataloader=test_dataloader, - criterion=criterion, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - schedule=schedule) - logger.info('start training') for epoch in range(gpc.config.num_epochs): - train_loss = train(engine) + train_loss = train(engine, train_dataloader, lr_scheduler) logger.info(f'epoch {epoch} - train loss: {train_loss}') if epoch % 2 == 0: - correct_sum, total_sum, eval_loss = eval(engine) + correct_sum, total_sum, eval_loss = eval(engine, test_dataloader) logger.info( f'epoch {epoch} - eval loss: {eval_loss}, total: {total_sum}, ' f'correct: {correct_sum}, acc: {correct_sum / total_sum}') diff --git a/tests/test_models/test_vision_transformer/configs/vit_2d.py b/tests/test_models/test_vision_transformer/configs/vit_2d.py index 92706e8cd..1fd1102fb 100644 --- a/tests/test_models/test_vision_transformer/configs/vit_2d.py +++ b/tests/test_models/test_vision_transformer/configs/vit_2d.py @@ -102,6 +102,6 @@ parallel = dict( tensor=dict(size=4, mode='2d'), ) -lr_scheduler = dict(type='LinearWarmupLR', warmup_epochs=5) - num_epochs = 60 + +lr_scheduler = dict(type='LinearWarmupLR', warmup_steps=5, total_steps=num_epochs) diff --git a/tests/test_models/test_vision_transformer/configs/vit_2p5d.py b/tests/test_models/test_vision_transformer/configs/vit_2p5d.py index f788cb704..3c16d684a 100644 --- a/tests/test_models/test_vision_transformer/configs/vit_2p5d.py +++ b/tests/test_models/test_vision_transformer/configs/vit_2p5d.py @@ -125,13 +125,6 @@ parallel = dict( tensor=dict(size=4, depth=1, mode='2.5d'), ) -lr_scheduler = dict( - type='LinearWarmupLR', - warmup_epochs=5 -) - -schedule = dict( - num_microbatches=8 -) - num_epochs = 60 + +lr_scheduler = dict(type='LinearWarmupLR', warmup_steps=5, total_steps=num_epochs) diff --git a/tests/test_models/test_vision_transformer/configs/vit_3d.py b/tests/test_models/test_vision_transformer/configs/vit_3d.py index c66212f04..ad041efd0 100644 --- a/tests/test_models/test_vision_transformer/configs/vit_3d.py +++ b/tests/test_models/test_vision_transformer/configs/vit_3d.py @@ -116,9 +116,14 @@ hooks = [ weight_parallel_mode=ParallelMode.PARALLEL_3D_WEIGHT, ), dict(type='LossHook'), - # dict(type='TensorboardHook', log_dir='./tfb_logs'), - # dict(type='SaveCheckpointHook', interval=5, checkpoint_dir='./ckpt'), - # dict(type='LoadCheckpointHook', epoch=20, checkpoint_dir='./ckpt') + dict( + type='LRSchedulerHook', + by_epoch=True, + lr_scheduler_cfg=dict( + type='LinearWarmupLR', + warmup_steps=5 + ) + ), ] parallel = dict( @@ -127,12 +132,4 @@ parallel = dict( tensor=dict(mode='3d', size=8), ) -# fp16 = dict(mode=AMP_TYPE.PARALLEL, initial_scale=2 ** 6) - -lr_scheduler = dict(type='LinearWarmupLR', warmup_epochs=5) - -# schedule = dict(num_microbatches=4) - num_epochs = 60 - -seed = 42 diff --git a/tests/test_models/test_vision_transformer/test_vit_2d/test_vit_2d.py b/tests/test_models/test_vision_transformer/test_vit_2d/test_vit_2d.py index fb32bea49..487ba335b 100644 --- a/tests/test_models/test_vision_transformer/test_vit_2d/test_vit_2d.py +++ b/tests/test_models/test_vision_transformer/test_vit_2d/test_vit_2d.py @@ -7,23 +7,25 @@ import pytest import torch.autograd import colossalai +from colossalai.builder import build_lr_scheduler from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc -from colossalai.engine import Engine from colossalai.logging import get_global_dist_logger from colossalai.nn.layer._parallel_utilities import _gather CONFIG_PATH = Path(__file__).parent.parent.joinpath('configs/vit_2d.py') -def eval(engine): +def eval(engine, test_dataloader): engine.eval() accumulated_loss = 0 correct_sum = 0 total_sum = 0 + num_steps = len(test_dataloader) + data_iter = iter(test_dataloader) - for i in range(engine.schedule.num_steps): - output, label, loss = engine.step() + for i in range(num_steps): + output, label, loss = engine.step(data_iter) accumulated_loss += loss.detach().cpu().numpy() output = _gather( @@ -40,18 +42,21 @@ def eval(engine): correct = torch.sum(label[0] == output) correct_sum += correct total_sum += label[0].size(0) - avg_loss = accumulated_loss / engine.schedule.num_steps + avg_loss = accumulated_loss / num_steps return correct_sum, total_sum, avg_loss -def train(engine): +def train(engine, train_dataloader, lr_scheduler): engine.train() accumulated_loss = 0 + num_steps = len(train_dataloader) + data_iter = iter(train_dataloader) - for i in range(engine.schedule.num_steps): - output, label, loss = engine.step() + for i in range(num_steps): + output, label, loss = engine.step(data_iter) accumulated_loss += loss.detach().cpu().numpy() - avg_loss = accumulated_loss / engine.schedule.num_steps + avg_loss = accumulated_loss / num_steps + lr_scheduler.step() return avg_loss @@ -59,25 +64,17 @@ def train(engine): @pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus") def test_2d_parallel_vision_transformer(): # init dist - model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize( - CONFIG_PATH) + engine, train_dataloader, test_dataloader = colossalai.initialize(CONFIG_PATH) + lr_scheduler = build_lr_scheduler(gpc.config.lr_scheduler, engine.optimizer) logger = get_global_dist_logger() - engine = Engine(model=model, - train_dataloader=train_dataloader, - test_dataloader=test_dataloader, - criterion=criterion, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - schedule=schedule) - logger.info('start training') for epoch in range(gpc.config.num_epochs): - train_loss = train(engine) + train_loss = train(engine, train_dataloader, lr_scheduler) logger.info(f'epoch {epoch} - train loss: {train_loss}') if epoch % 2 == 0: - correct_sum, total_sum, eval_loss = eval(engine) + correct_sum, total_sum, eval_loss = eval(engine, test_dataloader) logger.info( f'epoch {epoch} - eval loss: {eval_loss}, total: {total_sum}, ' f'correct: {correct_sum}, acc: {correct_sum / total_sum}') diff --git a/tests/test_models/test_vision_transformer/test_vit_2p5d/test_vit_2p5d.py b/tests/test_models/test_vision_transformer/test_vit_2p5d/test_vit_2p5d.py index 1a576d039..a8361d2e6 100644 --- a/tests/test_models/test_vision_transformer/test_vit_2p5d/test_vit_2p5d.py +++ b/tests/test_models/test_vision_transformer/test_vit_2p5d/test_vit_2p5d.py @@ -4,22 +4,25 @@ import pytest import torch.autograd import colossalai +from colossalai.builder import build_lr_scheduler from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc -from colossalai.engine import Engine from colossalai.logging import get_global_dist_logger from colossalai.nn.layer._parallel_utilities import _gather CONFIG_PATH = Path(__file__).parent.parent.joinpath('configs/vit_2p5d.py') -def eval(engine): + +def eval(engine, test_dataloader): engine.eval() accumulated_loss = 0 correct_sum = 0 total_sum = 0 + num_steps = len(test_dataloader) + data_iter = iter(test_dataloader) - for i in range(engine.schedule.num_steps): - output, label, loss = engine.step() + for i in range(num_steps): + output, label, loss = engine.step(data_iter) accumulated_loss += loss.detach().cpu().numpy() output = _gather( @@ -41,18 +44,21 @@ def eval(engine): correct = torch.sum(label[0] == output) correct_sum += correct total_sum += label[0].size(0) - avg_loss = accumulated_loss / engine.schedule.num_steps + avg_loss = accumulated_loss / num_steps return correct_sum, total_sum, avg_loss -def train(engine): +def train(engine, train_dataloader, lr_scheduler): engine.train() accumulated_loss = 0 + num_steps = len(train_dataloader) + data_iter = iter(train_dataloader) - for i in range(engine.schedule.num_steps): - output, label, loss = engine.step() + for i in range(num_steps): + output, label, loss = engine.step(data_iter) accumulated_loss += loss.detach().cpu().numpy() - avg_loss = accumulated_loss / engine.schedule.num_steps + avg_loss = accumulated_loss / num_steps + lr_scheduler.step() return avg_loss @@ -60,29 +66,21 @@ def train(engine): @pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus") def test_2p5d_parallel_vision_transformer(): # init dist - model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize( - CONFIG_PATH) + engine, train_dataloader, test_dataloader = colossalai.initialize(CONFIG_PATH) + lr_scheduler = build_lr_scheduler(gpc.config.lr_scheduler, engine.optimizer) logger = get_global_dist_logger() - engine = Engine(model=model, - train_dataloader=train_dataloader, - test_dataloader=test_dataloader, - criterion=criterion, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - schedule=schedule) - logger.info('start training') for epoch in range(gpc.config.num_epochs): - train_loss = train(engine) + train_loss = train(engine, train_dataloader, lr_scheduler) logger.info(f'epoch {epoch} - train loss: {train_loss}') if epoch % 2 == 0: - correct_sum, total_sum, eval_loss = eval(engine) + correct_sum, total_sum, eval_loss = eval(engine, test_dataloader) logger.info( f'epoch {epoch} - eval loss: {eval_loss}, total: {total_sum}, ' f'correct: {correct_sum}, acc: {correct_sum / total_sum}') if __name__ == '__main__': - test_2p5d_parallel_vision_transformer() \ No newline at end of file + test_2p5d_parallel_vision_transformer() diff --git a/tests/test_models/test_vision_transformer/test_vit_3d/test_vit_3d.py b/tests/test_models/test_vision_transformer/test_vit_3d/test_vit_3d.py index db78e9967..7bee2c78b 100644 --- a/tests/test_models/test_vision_transformer/test_vit_3d/test_vit_3d.py +++ b/tests/test_models/test_vision_transformer/test_vit_3d/test_vit_3d.py @@ -1,16 +1,14 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- - import time from pathlib import Path import torch from tqdm import tqdm -from colossalai import initialize +import colossalai from colossalai.context import ParallelMode from colossalai.core import global_context as gpc -from colossalai.engine import Engine from colossalai.logging import get_global_dist_logger from colossalai.trainer import Trainer from colossalai.trainer.metric import Accuracy3D @@ -29,7 +27,7 @@ def _train_epoch(epoch, engine): num_samples = 0 now = time.time() epoch_start = now - progress = range(engine.schedule.num_steps) + progress = range(engine._schedule.num_steps) if gpc.get_global_rank() == 0: progress = tqdm(progress, desc='[Epoch %d]' % epoch, miniters=1) for step in progress: @@ -68,7 +66,7 @@ def _eval(epoch, engine): ParallelMode.PARALLEL_3D_WEIGHT) total = 0 with torch.no_grad(): - for _ in range(engine.schedule.num_steps): + for _ in range(engine._schedule.num_steps): outputs, targets, loss = engine.step() if isinstance(outputs, (list, tuple)): outputs = outputs[0] @@ -80,32 +78,25 @@ def _eval(epoch, engine): print_rank_0( '[Epoch %d] Evaluation loss: %.3f | Acc: %.3f%%' % - (epoch, eval_loss / engine.schedule.num_steps, + (epoch, eval_loss / engine._schedule.num_steps, acc.get_accumulated_value() * 100), logger) def train(): - model, train_dataloader, test_dataloader, criterion, \ - optimizer, schedule, lr_scheduler = initialize(CONFIG_PATH) - + # init dist + engine, train_dataloader, test_dataloader = colossalai.initialize(CONFIG_PATH) logger = get_global_dist_logger() - engine = Engine(model=model, - train_dataloader=train_dataloader, - test_dataloader=test_dataloader, - criterion=criterion, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - schedule=schedule) logger.info("Engine is built", ranks=[0]) - trainer = Trainer(engine=engine, hooks_cfg=gpc.config.hooks, verbose=True) + trainer = Trainer(engine=engine, verbose=True) logger.info("Trainer is built", ranks=[0]) logger.info("Train start", ranks=[0]) trainer.fit(train_dataloader=train_dataloader, test_dataloader=test_dataloader, - max_epochs=gpc.config.num_epochs, + epochs=gpc.config.num_epochs, + hooks_cfg=gpc.config.hooks, display_progress=True, test_interval=1) diff --git a/tests/test_trainer/configs/test_trainer_resnet.py b/tests/test_trainer/configs/test_trainer_resnet.py index 8979f4b09..ff48d4e6c 100644 --- a/tests/test_trainer/configs/test_trainer_resnet.py +++ b/tests/test_trainer/configs/test_trainer_resnet.py @@ -3,6 +3,7 @@ from pathlib import Path BATCH_SIZE = 128 IMG_SIZE = 32 +num_epochs = 200 # resnet 50 model = dict( @@ -77,18 +78,14 @@ hooks = [ dict(type='AccuracyHook'), dict(type='LossHook'), dict(type='TensorboardHook', log_dir='./tfb_logs'), + dict( + type='LRSchedulerHook', + by_epoch=True, + lr_scheduler_cfg=dict( + type='CosineAnnealingLR', + warmup_steps=5 + ) + ), dict(type='SaveCheckpointHook', interval=5, checkpoint_dir='./ckpt'), - # dict(type='LoadCheckpointHook', epoch=20, checkpoint_dir='./ckpt') ] -# fp16 = dict( -# mode=AMP_TYPE.PARALLEL, -# initial_scale=1 -# ) - -lr_scheduler = dict( - type='CosineAnnealingLR', - T_max=200 -) - -num_epochs = 200 diff --git a/tests/test_trainer/configs/test_trainer_vit_2d.py b/tests/test_trainer/configs/test_trainer_vit_2d.py index 15c799039..1769f4afe 100644 --- a/tests/test_trainer/configs/test_trainer_vit_2d.py +++ b/tests/test_trainer/configs/test_trainer_vit_2d.py @@ -11,6 +11,7 @@ NUM_ATTENTION_HEADS = 8 SUMMA_DIM = 2 NUM_CLASSES = 10 DEPTH = 6 +num_epochs = 60 train_data = dict( dataset=dict(type='CIFAR10Dataset', @@ -52,13 +53,6 @@ optimizer = dict(type='Adam', lr=0.001, weight_decay=0) loss = dict(type='CrossEntropyLoss2D', ) -# model = dict( -# type='VanillaResNet', -# block_type='ResNetBasicBlock', -# layers=[2, 2, 2, 2], -# num_cls=10 -# ) - model = dict( type='VisionTransformerFromConfig', tensor_splitting_cfg=dict(type='ViTInputSplitter2D', ), @@ -114,8 +108,15 @@ hooks = [ dict(type='Accuracy2DHook'), dict(type='LossHook'), dict(type='TensorboardHook', log_dir='./tfb_logs'), + dict( + type='LRSchedulerHook', + by_epoch=True, + lr_scheduler_cfg=dict( + type='LinearWarmupLR', + warmup_steps=5 + ) + ), dict(type='SaveCheckpointHook', interval=5, checkpoint_dir='./ckpt'), - # dict(type='LoadCheckpointHook', epoch=20, checkpoint_dir='./ckpt') ] parallel = dict( @@ -125,11 +126,8 @@ parallel = dict( fp16 = dict(mode=AMP_TYPE.PARALLEL, initial_scale=2 ** 8) -lr_scheduler = dict(type='LinearWarmupLR', warmup_epochs=5) - -schedule = dict(num_microbatches=1) - -num_epochs = 60 -num_microbatches = 1 +engine = dict( + schedule=dict(num_microbatches=1) +) logging = dict(root_path='./logs') diff --git a/tests/test_trainer/test_trainer.py b/tests/test_trainer/test_trainer.py index 0c0a458b3..6a7681d00 100644 --- a/tests/test_trainer/test_trainer.py +++ b/tests/test_trainer/test_trainer.py @@ -1,25 +1,16 @@ import colossalai from colossalai.core import global_context as gpc -from colossalai.engine import Engine from colossalai.logging import get_global_dist_logger from colossalai.trainer import Trainer def test_trainer(): - model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize() + engine, train_dataloader, test_dataloader = colossalai.initialize() logger = get_global_dist_logger() - engine = Engine( - model=model, - criterion=criterion, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - schedule=schedule - ) logger.info("engine is built", ranks=[0]) trainer = Trainer(engine=engine, - hooks_cfg=gpc.config.hooks, verbose=True) logger.info("trainer is built", ranks=[0]) @@ -27,7 +18,8 @@ def test_trainer(): trainer.fit( train_dataloader=train_dataloader, test_dataloader=test_dataloader, - max_epochs=gpc.config.num_epochs, + hooks_cfg=gpc.config.hooks, + epochs=gpc.config.num_epochs, display_progress=False, test_interval=5 ) diff --git a/tests/test_zero_tensor_parallel/test_vit_2d/test_vit_2d.py b/tests/test_zero_tensor_parallel/test_vit_2d/test_vit_2d.py index 6533b3a6d..5c78dfcc2 100644 --- a/tests/test_zero_tensor_parallel/test_vit_2d/test_vit_2d.py +++ b/tests/test_zero_tensor_parallel/test_vit_2d/test_vit_2d.py @@ -18,14 +18,16 @@ level = os.environ['LEVEL'] CONFIG_PATH = Path(__file__).parent.parent.joinpath(f'configs/vit_2d_zero{level}.py') -def eval(engine): +def eval_epoch(engine: Engine, test_dataloader): engine.eval() accumulated_loss = 0 correct_sum = 0 total_sum = 0 + num_steps = len(test_dataloader) + data_iter = iter(test_dataloader) - for i in range(engine.schedule.num_steps): - output, label, loss = engine.step() + for i in range(num_steps): + output, label, loss = engine.step(data_iter) accumulated_loss += loss.detach().cpu().numpy() output = _gather( @@ -42,18 +44,19 @@ def eval(engine): correct = torch.sum(label[0] == output) correct_sum += correct total_sum += label[0].size(0) - avg_loss = accumulated_loss / engine.schedule.num_steps + avg_loss = accumulated_loss / num_steps return correct_sum, total_sum, avg_loss -def train(engine): +def train_epoch(engine, train_dataloader): engine.train() accumulated_loss = 0 - - for i in range(engine.schedule.num_steps): - output, label, loss = engine.step() + num_steps = len(train_dataloader) + data_iter = iter(train_dataloader) + for i in range(num_steps): + output, label, loss = engine.step(data_iter) accumulated_loss += loss.detach().cpu().numpy() - avg_loss = accumulated_loss / engine.schedule.num_steps + avg_loss = accumulated_loss / num_steps return avg_loss @@ -61,30 +64,17 @@ def train(engine): @pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus") def test_2d_parallel_vision_transformer(): # init dist - model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize( - CONFIG_PATH) + engine, train_dataloader, test_dataloader = colossalai.initialize(CONFIG_PATH) logger = get_global_dist_logger() - engine = Engine(model=model, - train_dataloader=train_dataloader, - test_dataloader=test_dataloader, - criterion=criterion, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - schedule=schedule) - - # for param in model.parameters(): - # if isinstance(param, torch.HalfTensor): - # print(param.shape) - logger.info('start training') for epoch in range(gpc.config.num_epochs): - train_loss = train(engine) + train_loss = train_epoch(engine, train_dataloader) logger.info(f'epoch {epoch} - train loss: {train_loss}') if epoch % 2 == 0: - correct_sum, total_sum, eval_loss = eval(engine) + correct_sum, total_sum, eval_loss = eval_epoch(engine, test_dataloader) logger.info( f'epoch {epoch} - eval loss: {eval_loss}, total: {total_sum}, ' f'correct: {correct_sum}, acc: {correct_sum / total_sum}')