mirror of https://github.com/hpcaitech/ColossalAI
Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b7699
.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
pull/28/head
parent
2b05de4c64
commit
3defa32aee
14
README.md
14
README.md
|
@ -42,26 +42,18 @@ pip install -v --no-cache-dir --global-option="--cuda_ext" .
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import colossalai
|
import colossalai
|
||||||
from colossalai.engine import Engine
|
|
||||||
from colossalai.trainer import Trainer
|
from colossalai.trainer import Trainer
|
||||||
from colossalai.core import global_context as gpc
|
from colossalai.core import global_context as gpc
|
||||||
|
|
||||||
model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize()
|
engine, train_dataloader, test_dataloader = colossalai.initialize()
|
||||||
engine = Engine(
|
|
||||||
model=model,
|
|
||||||
criterion=criterion,
|
|
||||||
optimizer=optimizer,
|
|
||||||
lr_scheduler=lr_scheduler,
|
|
||||||
schedule=schedule
|
|
||||||
)
|
|
||||||
|
|
||||||
trainer = Trainer(engine=engine,
|
trainer = Trainer(engine=engine,
|
||||||
hooks_cfg=gpc.config.hooks,
|
|
||||||
verbose=True)
|
verbose=True)
|
||||||
trainer.fit(
|
trainer.fit(
|
||||||
train_dataloader=train_dataloader,
|
train_dataloader=train_dataloader,
|
||||||
test_dataloader=test_dataloader,
|
test_dataloader=test_dataloader,
|
||||||
max_epochs=gpc.config.num_epochs,
|
epochs=gpc.config.num_epochs,
|
||||||
|
hooks_cfg=gpc.config.hooks,
|
||||||
display_progress=True,
|
display_progress=True,
|
||||||
test_interval=5
|
test_interval=5
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,2 +1,10 @@
|
||||||
from .builder import *
|
from .builder import (build_schedule, build_lr_scheduler, build_model, build_optimizer, build_optimizer_wrapper,
|
||||||
|
build_layer, build_loss, build_hooks, build_dataset, build_transform, build_data_sampler,
|
||||||
|
build_gradient_handler)
|
||||||
from .pipeline import ModelInitializer
|
from .pipeline import ModelInitializer
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'build_schedule', 'build_lr_scheduler', 'build_model', 'build_optimizer', 'build_optimizer_wrapper',
|
||||||
|
'build_layer', 'build_loss', 'build_hooks', 'build_dataset', 'build_transform', 'build_data_sampler',
|
||||||
|
'build_gradient_handler', 'ModelInitializer'
|
||||||
|
]
|
||||||
|
|
|
@ -181,18 +181,6 @@ def build_transform(config):
|
||||||
return build_from_registry(config, TRANSFORMS)
|
return build_from_registry(config, TRANSFORMS)
|
||||||
|
|
||||||
|
|
||||||
def build_pipe_alloc_policy(config):
|
|
||||||
"""Returns a pipeline allocation policy object constructed from `config`.
|
|
||||||
|
|
||||||
:param config: A python dict or a :class:`colossalai.context.Config` object
|
|
||||||
containing information used in the construction of the return object
|
|
||||||
:type config: dict or :class:`colossalai.context.Config`
|
|
||||||
:return: A pipeline allocation policy object
|
|
||||||
:rtype:
|
|
||||||
"""
|
|
||||||
return build_from_registry(config, PIPE_ALLOC_POLICY)
|
|
||||||
|
|
||||||
|
|
||||||
def build_data_sampler(config, dataset):
|
def build_data_sampler(config, dataset):
|
||||||
"""Returns a data sampler object of :class:`colossalai.nn.data.sampler.BaseSampler`
|
"""Returns a data sampler object of :class:`colossalai.nn.data.sampler.BaseSampler`
|
||||||
constructed from `config`.
|
constructed from `config`.
|
||||||
|
@ -235,7 +223,7 @@ def build_optimizer_wrapper(config, optimizer, model=None):
|
||||||
return OPTIMIZER_WRAPPERS.get_module(mod_type)(optimizer, **config_)
|
return OPTIMIZER_WRAPPERS.get_module(mod_type)(optimizer, **config_)
|
||||||
|
|
||||||
|
|
||||||
def build_lr_scheduler(config, optimizer, total_steps, num_steps_per_epoch):
|
def build_lr_scheduler(config, optimizer):
|
||||||
"""Returns a learning rate scheduler object of :class:`torch.optim.lr_scheduler`
|
"""Returns a learning rate scheduler object of :class:`torch.optim.lr_scheduler`
|
||||||
constructed from `config`, `optimizer`, `total_steps` and `num_steps_per_epoch`.
|
constructed from `config`, `optimizer`, `total_steps` and `num_steps_per_epoch`.
|
||||||
|
|
||||||
|
@ -254,9 +242,16 @@ def build_lr_scheduler(config, optimizer, total_steps, num_steps_per_epoch):
|
||||||
"""
|
"""
|
||||||
config_ = config.copy()
|
config_ = config.copy()
|
||||||
mod_type = config_.pop('type')
|
mod_type = config_.pop('type')
|
||||||
# warmup epochs will overwrite warmup steps
|
return LR_SCHEDULERS.get_module(mod_type)(optimizer, **config_)
|
||||||
if 'warmup_epochs' in config_:
|
|
||||||
warmup_epochs = config_.pop('warmup_epochs')
|
|
||||||
config_['warmup_steps'] = int(num_steps_per_epoch * warmup_epochs)
|
def build_schedule(config):
|
||||||
return LR_SCHEDULERS.get_module(mod_type)(optimizer, total_steps, num_steps_per_epoch=num_steps_per_epoch,
|
"""Returns a schedule of :class:`colossalai.engine.schedule.BaseSchedule`.
|
||||||
**config_)
|
|
||||||
|
:param config: A python dict or a :class:`colossalai.context.Config` object
|
||||||
|
containing information used in the construction of the return object
|
||||||
|
:type config: dict or :class:`colossalai.context.Config`
|
||||||
|
:return: An object of :class:`colossalai.engine.schedule.BaseSchedule`
|
||||||
|
:rtype: :class:`colossalai.engine.schedule.BaseSchedule`
|
||||||
|
"""
|
||||||
|
return build_from_registry(config, SCHEDULE)
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from .amp_type import AMP_TYPE
|
|
||||||
from ._base_engine import Engine
|
from ._base_engine import Engine
|
||||||
from .gradient_handler import *
|
from .gradient_handler import *
|
||||||
from .schedule import *
|
from .schedule import *
|
||||||
|
from .amp import *
|
||||||
|
|
||||||
|
|
||||||
__all__ = ['Engine']
|
__all__ = ['Engine']
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
# -*- encoding: utf-8 -*-
|
# -*- encoding: utf-8 -*-
|
||||||
|
|
||||||
from typing import Optional
|
from torch.nn import Module
|
||||||
|
from torch.nn.modules.loss import _Loss
|
||||||
|
from torch.optim import Optimizer
|
||||||
|
|
||||||
from colossalai.builder import build_gradient_handler
|
from colossalai.builder import build_gradient_handler
|
||||||
from colossalai.context import ParallelMode
|
from colossalai.context import ParallelMode
|
||||||
|
@ -9,89 +11,103 @@ from colossalai.core import global_context as gpc
|
||||||
from colossalai.logging import get_global_dist_logger
|
from colossalai.logging import get_global_dist_logger
|
||||||
from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
|
from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
|
||||||
ZeroRedundancyOptimizer_Level_3)
|
ZeroRedundancyOptimizer_Level_3)
|
||||||
from torch.nn import Module
|
from .schedule import BaseSchedule
|
||||||
from torch.nn.modules.loss import _Loss
|
|
||||||
from torch.optim import Optimizer
|
|
||||||
from torch.optim.lr_scheduler import _LRScheduler
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
from .schedule import BaseSchedule, NoPipelineSchedule
|
|
||||||
|
|
||||||
|
|
||||||
class Engine:
|
class Engine:
|
||||||
"""Basic engine class for training and evaluation. It runs a specific process method
|
"""Basic engine class for training and evaluation. It runs a specific process method
|
||||||
:meth:`step` which is based on the given :attr:`schedule` over each batch of a dataset.
|
:meth:`step` which is based on the given :attr:`schedule` over each batch of a dataset.
|
||||||
|
It controls a iteration in training.
|
||||||
|
|
||||||
:param train_dataloader: Dataloader in training
|
|
||||||
:param test_dataloader: Dataloader in evaluation
|
|
||||||
:param model: The neural network model
|
:param model: The neural network model
|
||||||
:param criterion: Criterion for calculating loss
|
|
||||||
:param optimizer: Optimizer for updating the parameters
|
:param optimizer: Optimizer for updating the parameters
|
||||||
:param lr_scheduler: Learning rate scheduler ajusting learning rate during the training or evaluation
|
:param step_schedule: Running schedule in :meth:`step`
|
||||||
:param schedule: Running schedule in :meth:`step`
|
:param gradient_accumulation: Steps of gradient accumulation
|
||||||
:type train_dataloader: DataLoader, optional
|
:param gradient_clipping: The norm of gradient clipping
|
||||||
:type test_dataloader: DataLoader, optional
|
|
||||||
:type model: Module
|
:type model: Module
|
||||||
:type criterion: _Loss, optional
|
:type optimizer: Optimizer
|
||||||
:type optimizer: Optimizer, optional
|
:type step_schedule: BaseSchedule, optional
|
||||||
:type lr_scheduler: _LRScheduler, optional
|
:type gradient_accumulation: int, optional
|
||||||
:type schedule: BaseSchedule, optional
|
:type gradient_clipping: float, optional
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
train_dataloader: Optional[DataLoader] = None,
|
model: Module,
|
||||||
test_dataloader: Optional[DataLoader] = None,
|
optimizer: Optimizer,
|
||||||
model: Module = None,
|
criterion: _Loss,
|
||||||
criterion: _Loss = None,
|
step_schedule: BaseSchedule,
|
||||||
optimizer: Optimizer = None,
|
gradient_handlers: list = None,
|
||||||
lr_scheduler: Optional[_LRScheduler] = None,
|
gradient_accumulation: int = 1,
|
||||||
schedule: BaseSchedule = None):
|
gradient_clipping: float = 0.0,
|
||||||
self.train_dataloader = train_dataloader
|
):
|
||||||
self.test_dataloader = test_dataloader
|
self._model = model
|
||||||
assert model is not None, "Engine requires a model"
|
self._optimizer = optimizer
|
||||||
self.model = model
|
self._criterion = criterion
|
||||||
self.criterion = criterion
|
self._schedule = step_schedule
|
||||||
self.optimizer = optimizer
|
|
||||||
self.lr_scheduler = lr_scheduler
|
# schedule initialize
|
||||||
self.schedule = schedule if schedule is not None \
|
self._schedule.initialize(model, optimizer)
|
||||||
else NoPipelineSchedule()
|
|
||||||
|
# state
|
||||||
|
self.training = True # default
|
||||||
|
|
||||||
|
# gradient accumulation
|
||||||
|
assert gradient_accumulation > 0, 'gradient accumulation size must be larger than 0'
|
||||||
|
self._grad_accum_size = gradient_accumulation
|
||||||
|
self._grad_clip = gradient_clipping
|
||||||
self._logger = get_global_dist_logger()
|
self._logger = get_global_dist_logger()
|
||||||
|
|
||||||
# build gradient handler
|
# build gradient handler
|
||||||
self._gradient_handlers = []
|
self._gradient_handlers = []
|
||||||
gradient_handler_cfg = []
|
|
||||||
|
|
||||||
if hasattr(gpc.config, 'gradient_handler'):
|
if gradient_handlers is not None:
|
||||||
assert isinstance(gpc.config.gradient_handler, list), \
|
assert isinstance(gradient_handlers, list), \
|
||||||
f'argument gradient_handler_cfg expected type list, ' \
|
f'argument gradient_handler_cfg expected type list, ' \
|
||||||
f'but got type {type(gpc.config.gradient_handler)}'
|
f'but got type {type(gradient_handlers)}'
|
||||||
gradient_handler_cfg = gpc.config.gradient_handler
|
elif isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2,
|
||||||
elif isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2,
|
ZeroRedundancyOptimizer_Level_3)):
|
||||||
ZeroRedundancyOptimizer_Level_3)):
|
gradient_handlers = [dict(type='ZeROGradientHandler')]
|
||||||
gradient_handler_cfg = [dict(type='ZeROGradientHandler')]
|
|
||||||
self._logger.info(
|
self._logger.info(
|
||||||
"Training with zero is detected, ZeROGradientHandler is automatically "
|
"Training with zero is detected, ZeROGradientHandler is automatically "
|
||||||
"added even though not specified in the configuration",
|
"added even though not specified in the configuration",
|
||||||
ranks=[0])
|
ranks=[0])
|
||||||
elif gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(
|
elif gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(
|
||||||
ParallelMode.DATA) > 1:
|
ParallelMode.DATA) > 1:
|
||||||
gradient_handler_cfg = [dict(type='DataParallelGradientHandler')]
|
gradient_handlers = [dict(type='DataParallelGradientHandler')]
|
||||||
self._logger.info(
|
self._logger.info(
|
||||||
"Data parallel training is detected, DataParallelGradientHandler is automatically "
|
"Data parallel training is detected, DataParallelGradientHandler is automatically "
|
||||||
"added even though not specified in the configuration",
|
"added even though not specified in the configuration",
|
||||||
ranks=[0])
|
ranks=[0])
|
||||||
if len(gradient_handler_cfg) == 0:
|
|
||||||
|
if gradient_handlers is None:
|
||||||
self._logger.warning(
|
self._logger.warning(
|
||||||
"No gradient handler is set up, please make sure you do not need "
|
"No gradient handler is set up, please make sure you do not need "
|
||||||
"to all-reduce the gradients after a training step.",
|
"to all-reduce the gradients after a training step.",
|
||||||
ranks=[0])
|
ranks=[0])
|
||||||
for cfg in gradient_handler_cfg:
|
else:
|
||||||
handler = build_gradient_handler(cfg, self.model, self.optimizer)
|
for cfg in gradient_handlers:
|
||||||
self._gradient_handlers.append(handler)
|
handler = build_gradient_handler(cfg, model, optimizer)
|
||||||
|
self._gradient_handlers.append(handler)
|
||||||
|
|
||||||
self.schedule.initialize(self.train_dataloader, self.model,
|
@property
|
||||||
self.criterion, self.optimizer,
|
def model(self):
|
||||||
self.lr_scheduler)
|
return self._model
|
||||||
self.forward_only = False
|
|
||||||
|
@property
|
||||||
|
def optimizer(self):
|
||||||
|
return self._optimizer
|
||||||
|
|
||||||
|
@property
|
||||||
|
def criterion(self):
|
||||||
|
return self._criterion
|
||||||
|
|
||||||
|
@property
|
||||||
|
def schedule(self):
|
||||||
|
return self._schedule
|
||||||
|
|
||||||
|
@property
|
||||||
|
def gradient_accumulation(self):
|
||||||
|
return self._grad_accum_size
|
||||||
|
|
||||||
def handle_gradient(self):
|
def handle_gradient(self):
|
||||||
"""Handles all-reduce operations of gradients across different parallel groups.
|
"""Handles all-reduce operations of gradients across different parallel groups.
|
||||||
|
@ -99,72 +115,62 @@ class Engine:
|
||||||
for handler in self._gradient_handlers:
|
for handler in self._gradient_handlers:
|
||||||
handler.handle_gradient()
|
handler.handle_gradient()
|
||||||
|
|
||||||
def set_dataloader(self, data: DataLoader, train: bool = True):
|
|
||||||
"""Sets dataloader in training or evaluation.
|
|
||||||
|
|
||||||
:param data: Dataloader to be set
|
|
||||||
:param train: Set training dataloader if True, otherwise evaluation dataloader
|
|
||||||
:type data: DataLoader
|
|
||||||
:type train: bool
|
|
||||||
"""
|
|
||||||
if train:
|
|
||||||
self.train_dataloader = data
|
|
||||||
else:
|
|
||||||
self.test_dataloader = data
|
|
||||||
|
|
||||||
def get_model(self):
|
|
||||||
"""Returns the neural network model in the engine.
|
|
||||||
"""
|
|
||||||
return self.model
|
|
||||||
def get_optimizer(self):
|
|
||||||
"""Returns optimizier in the engine.
|
|
||||||
"""
|
|
||||||
return self.optimizer
|
|
||||||
|
|
||||||
def get_lr_scheduler(self):
|
|
||||||
"""Returns the learning rate scheduler in the engine.
|
|
||||||
"""
|
|
||||||
return self.lr_scheduler
|
|
||||||
|
|
||||||
def train(self):
|
def train(self):
|
||||||
"""Sets the model to training mode.
|
"""Sets the model to training mode.
|
||||||
"""
|
"""
|
||||||
self.forward_only = False
|
self.training = True
|
||||||
self.schedule.train(dataloader=self.train_dataloader, mode=True)
|
self._model.train()
|
||||||
|
|
||||||
def eval(self):
|
def eval(self):
|
||||||
"""Sets the model to evaluation mode.
|
"""Sets the model to evaluation mode.
|
||||||
"""
|
"""
|
||||||
self.forward_only = True
|
self.training = False
|
||||||
self.schedule.train(dataloader=self.test_dataloader, mode=False)
|
self._model.eval()
|
||||||
|
|
||||||
def is_train(self):
|
def step(self,
|
||||||
"""Returns True if it is in training, otherwise False.
|
data_iter,
|
||||||
"""
|
is_last_iteration: bool = False,
|
||||||
return not self.forward_only
|
return_loss=True):
|
||||||
|
|
||||||
def get_lr(self):
|
|
||||||
"""Gets current learning rate.
|
|
||||||
"""
|
|
||||||
return self.schedule.get_lr()
|
|
||||||
|
|
||||||
def step(self, return_loss=True):
|
|
||||||
"""A running step based on the schedule. Usually, it runs a training or
|
"""A running step based on the schedule. Usually, it runs a training or
|
||||||
evaluation over a batch of dataset.
|
evaluation over a batch of dataset.
|
||||||
|
|
||||||
|
:param data_iter: Data iterator of the dataset
|
||||||
|
:param is_last_iteration: If True, this iteration is the last iteration in the epoch
|
||||||
:param return_loss: loss will be returned if True
|
:param return_loss: loss will be returned if True
|
||||||
:type return_loss: bool
|
:type data_iter: Iterator
|
||||||
|
:type is_last_iteration: bool, optional
|
||||||
|
:type return_loss: bool, optional
|
||||||
:return: (output, lablel, loss)
|
:return: (output, lablel, loss)
|
||||||
"""
|
"""
|
||||||
self.schedule.zero_grad(forward_only=self.forward_only)
|
if self.training:
|
||||||
|
self._optimizer.zero_grad()
|
||||||
|
|
||||||
output, label, loss = self.schedule.forward_backward_step(
|
# differentiate training and eval with grad accum
|
||||||
forward_only=self.forward_only, return_loss=return_loss)
|
if self.training:
|
||||||
|
for i in range(self._grad_accum_size):
|
||||||
|
output, label, loss = self._schedule.forward_backward_step(
|
||||||
|
data_iter, self._model, self._criterion, self._optimizer,
|
||||||
|
forward_only=False,
|
||||||
|
grad_accum_size=self._grad_accum_size,
|
||||||
|
return_loss=return_loss)
|
||||||
|
|
||||||
if not self.forward_only:
|
if i == self._grad_accum_size - 1:
|
||||||
# all reduce gradients
|
# all reduce gradients
|
||||||
self.handle_gradient()
|
self.handle_gradient()
|
||||||
|
self._schedule.optimizer_step(self._model, self._optimizer, self._grad_clip)
|
||||||
|
else:
|
||||||
|
output, label, loss = self._schedule.forward_backward_step(
|
||||||
|
data_iter, self._model, self._criterion, self._optimizer,
|
||||||
|
forward_only=True,
|
||||||
|
grad_accum_size=1,
|
||||||
|
return_loss=return_loss)
|
||||||
|
|
||||||
self.schedule.step()
|
# consume the remaining dataset left out due to gradient accumulation
|
||||||
|
if is_last_iteration:
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
_ = next(data_iter)
|
||||||
|
except StopIteration:
|
||||||
|
break
|
||||||
|
|
||||||
return output, label, loss
|
return output, label, loss
|
||||||
|
|
|
@ -0,0 +1,2 @@
|
||||||
|
from .grad_scaler import GradScaler
|
||||||
|
from .amp_type import AMP_TYPE
|
|
@ -0,0 +1,577 @@
|
||||||
|
# modified from https://github.com/pytorch/pytorch/blob/master/torch/cuda/amp/grad_scaler.p
|
||||||
|
import torch
|
||||||
|
from collections import defaultdict, abc
|
||||||
|
import warnings
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
from colossalai.context import ParallelMode
|
||||||
|
import torch.distributed as dist
|
||||||
|
from colossalai.core import global_context as gpc
|
||||||
|
|
||||||
|
|
||||||
|
class _MultiDeviceReplicator(object):
|
||||||
|
"""
|
||||||
|
Lazily serves copies of a tensor to requested devices. Copies are cached per-device.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, master_tensor: torch.Tensor) -> None:
|
||||||
|
assert master_tensor.is_cuda or master_tensor.device.type == 'xla'
|
||||||
|
self.master = master_tensor
|
||||||
|
self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}
|
||||||
|
|
||||||
|
def get(self, device) -> torch.Tensor:
|
||||||
|
retval = self._per_device_tensors.get(device, None)
|
||||||
|
if retval is None:
|
||||||
|
retval = self.master.to(
|
||||||
|
device=device, non_blocking=True, copy=True)
|
||||||
|
self._per_device_tensors[device] = retval
|
||||||
|
return retval
|
||||||
|
|
||||||
|
|
||||||
|
# Defines default_factory for GradScaler's _per_optimizer_states defaultdict,
|
||||||
|
# as well as associated "enum" values. Prefers defining these at top level because
|
||||||
|
# - Lambdas can't be pickled, so we don't want to supply a lambda as the factory.
|
||||||
|
# - Defining READY, UNSCALED, STEPPED and _refresh_per_optimizer_state within GradScaler
|
||||||
|
# causes a circular reference, which we'd rather avoid.
|
||||||
|
class OptState(Enum):
|
||||||
|
READY = 0
|
||||||
|
UNSCALED = 1
|
||||||
|
STEPPED = 2
|
||||||
|
|
||||||
|
|
||||||
|
def _refresh_per_optimizer_state():
|
||||||
|
return {"stage": OptState.READY, "found_inf_per_device": {}}
|
||||||
|
|
||||||
|
|
||||||
|
class GradScaler(object):
|
||||||
|
_scale: Optional[torch.Tensor]
|
||||||
|
_grows_tracker: Optional[torch.Tensor]
|
||||||
|
_per_optimizer_states: Dict[int, Dict[str, Any]]
|
||||||
|
"""
|
||||||
|
An instance ``scaler`` of :class:`GradScaler` helps perform the steps of gradient scaling
|
||||||
|
conveniently.
|
||||||
|
|
||||||
|
* ``scaler.scale(loss)`` multiplies a given loss by ``scaler``'s current scale factor.
|
||||||
|
* ``scaler.step(optimizer)`` safely unscales gradients and calls ``optimizer.step()``.
|
||||||
|
* ``scaler.update()`` updates ``scaler``'s scale factor.
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
# Creates a GradScaler once at the beginning of training.
|
||||||
|
scaler = GradScaler()
|
||||||
|
|
||||||
|
for epoch in epochs:
|
||||||
|
for input, target in data:
|
||||||
|
optimizer.zero_grad()
|
||||||
|
output = model(input)
|
||||||
|
loss = loss_fn(output, target)
|
||||||
|
|
||||||
|
# Scales loss. Calls backward() on scaled loss to create scaled gradients.
|
||||||
|
scaler.scale(loss).backward()
|
||||||
|
|
||||||
|
# scaler.step() first unscales gradients of the optimizer's params.
|
||||||
|
# If gradients don't contain infs/NaNs, optimizer.step() is then called,
|
||||||
|
# otherwise, optimizer.step() is skipped.
|
||||||
|
scaler.step(optimizer)
|
||||||
|
|
||||||
|
# Updates the scale for next iteration.
|
||||||
|
scaler.update()
|
||||||
|
|
||||||
|
See the :ref:`Automatic Mixed Precision examples<amp-examples>` for usage
|
||||||
|
(along with autocasting) in more complex cases like gradient clipping, gradient accumulation, gradient penalty,
|
||||||
|
and multiple losses/optimizers.
|
||||||
|
|
||||||
|
``scaler`` dynamically estimates the scale factor each iteration. To minimize gradient underflow,
|
||||||
|
a large scale factor should be used. However, ``float16`` values can "overflow" (become inf or NaN) if
|
||||||
|
the scale factor is too large. Therefore, the optimal scale factor is the largest factor that can be used
|
||||||
|
without incurring inf or NaN gradient values.
|
||||||
|
``scaler`` approximates the optimal scale factor over time by checking the gradients for infs and NaNs during every
|
||||||
|
``scaler.step(optimizer)`` (or optional separate ``scaler.unscale_(optimizer)``, see :meth:`unscale_`).
|
||||||
|
|
||||||
|
* If infs/NaNs are found, ``scaler.step(optimizer)`` skips the underlying ``optimizer.step()`` (so the params
|
||||||
|
themselves remain uncorrupted) and ``update()`` multiplies the scale by ``backoff_factor``.
|
||||||
|
|
||||||
|
* If no infs/NaNs are found, ``scaler.step(optimizer)`` runs the underlying ``optimizer.step()`` as usual.
|
||||||
|
If ``growth_interval`` unskipped iterations occur consecutively, ``update()`` multiplies the scale by
|
||||||
|
``growth_factor``.
|
||||||
|
|
||||||
|
The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its
|
||||||
|
value calibrates. ``scaler.step`` will skip the underlying ``optimizer.step()`` for these
|
||||||
|
iterations. After that, step skipping should occur rarely (once every few hundred or thousand iterations).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
init_scale (float, optional, default=2.**16): Initial scale factor.
|
||||||
|
growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during
|
||||||
|
:meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations.
|
||||||
|
backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during
|
||||||
|
:meth:`update` if inf/NaN gradients occur in an iteration.
|
||||||
|
growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients
|
||||||
|
that must occur for the scale to be multiplied by ``growth_factor``.
|
||||||
|
enabled (bool, optional, default=True): If ``False``, disables gradient scaling. :meth:`step` simply
|
||||||
|
invokes the underlying ``optimizer.step()``, and other methods become no-ops.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
init_scale=2.**16,
|
||||||
|
growth_factor=2.0,
|
||||||
|
backoff_factor=0.5,
|
||||||
|
growth_interval=2000,
|
||||||
|
enabled=True):
|
||||||
|
if enabled and not torch.cuda.is_available():
|
||||||
|
warnings.warn(
|
||||||
|
"torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.")
|
||||||
|
self._enabled = False
|
||||||
|
else:
|
||||||
|
self._enabled = enabled
|
||||||
|
|
||||||
|
if self._enabled:
|
||||||
|
assert growth_factor > 1.0, "The growth factor must be > 1.0."
|
||||||
|
assert backoff_factor < 1.0, "The backoff factor must be < 1.0."
|
||||||
|
|
||||||
|
self._init_scale = init_scale
|
||||||
|
# self._scale will be lazily initialized during the first call to scale()
|
||||||
|
self._scale = None
|
||||||
|
self._growth_factor = growth_factor
|
||||||
|
self._backoff_factor = backoff_factor
|
||||||
|
self._growth_interval = growth_interval
|
||||||
|
self._init_growth_tracker = 0
|
||||||
|
# self._growth_tracker will be lazily initialized during the first call to scale()
|
||||||
|
self._growth_tracker = None
|
||||||
|
self._per_optimizer_states = defaultdict(
|
||||||
|
_refresh_per_optimizer_state)
|
||||||
|
|
||||||
|
def _check_scale_growth_tracker(self, funcname) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration."
|
||||||
|
assert self._scale is not None, "Attempted {} but _scale is None. ".format(
|
||||||
|
funcname) + fix
|
||||||
|
assert self._growth_tracker is not None, "Attempted {} but _growth_tracker is None. ".format(
|
||||||
|
funcname) + fix
|
||||||
|
return (self._scale, self._growth_tracker)
|
||||||
|
|
||||||
|
def _lazy_init_scale_growth_tracker(self, dev):
|
||||||
|
assert self._growth_tracker is None, "_growth_tracker initialized before _scale"
|
||||||
|
self._scale = torch.full(
|
||||||
|
(1,), self._init_scale, dtype=torch.float32, device=dev)
|
||||||
|
self._growth_tracker = torch.full(
|
||||||
|
(1,), self._init_growth_tracker, dtype=torch.int32, device=dev)
|
||||||
|
|
||||||
|
def scale(self, outputs):
|
||||||
|
"""
|
||||||
|
Multiplies ('scales') a tensor or list of tensors by the scale factor.
|
||||||
|
|
||||||
|
Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned
|
||||||
|
unmodified.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
outputs (Tensor or iterable of Tensors): Outputs to scale.
|
||||||
|
"""
|
||||||
|
if not self._enabled:
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
# Short-circuit for the common case.
|
||||||
|
if isinstance(outputs, torch.Tensor):
|
||||||
|
assert outputs.is_cuda or outputs.device.type == 'xla'
|
||||||
|
if self._scale is None:
|
||||||
|
self._lazy_init_scale_growth_tracker(outputs.device)
|
||||||
|
assert self._scale is not None
|
||||||
|
return outputs * self._scale.to(device=outputs.device, non_blocking=True)
|
||||||
|
|
||||||
|
# Invoke the more complex machinery only if we're treating multiple outputs.
|
||||||
|
# holds a reference that can be overwritten by apply_scale
|
||||||
|
stash: List[_MultiDeviceReplicator] = []
|
||||||
|
|
||||||
|
def apply_scale(val):
|
||||||
|
if isinstance(val, torch.Tensor):
|
||||||
|
assert val.is_cuda or val.device.type == 'xla'
|
||||||
|
if len(stash) == 0:
|
||||||
|
if self._scale is None:
|
||||||
|
self._lazy_init_scale_growth_tracker(val.device)
|
||||||
|
assert self._scale is not None
|
||||||
|
stash.append(_MultiDeviceReplicator(self._scale))
|
||||||
|
return val * stash[0].get(val.device)
|
||||||
|
elif isinstance(val, abc.Iterable):
|
||||||
|
iterable = map(apply_scale, val)
|
||||||
|
if isinstance(val, list) or isinstance(val, tuple):
|
||||||
|
return type(val)(iterable)
|
||||||
|
else:
|
||||||
|
return iterable
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"outputs must be a Tensor or an iterable of Tensors")
|
||||||
|
|
||||||
|
return apply_scale(outputs)
|
||||||
|
|
||||||
|
def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16):
|
||||||
|
per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
|
||||||
|
per_device_found_inf = _MultiDeviceReplicator(found_inf)
|
||||||
|
|
||||||
|
# To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
|
||||||
|
# There could be hundreds of grads, so we'd like to iterate through them just once.
|
||||||
|
# However, we don't know their devices or dtypes in advance.
|
||||||
|
|
||||||
|
# https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
|
||||||
|
# Google says mypy struggles with defaultdicts type annotations.
|
||||||
|
per_device_and_dtype_grads = defaultdict(
|
||||||
|
lambda: defaultdict(list)) # type: ignore[var-annotated]
|
||||||
|
with torch.no_grad():
|
||||||
|
for group in optimizer.param_groups:
|
||||||
|
for param in group["params"]:
|
||||||
|
if param.grad is None:
|
||||||
|
continue
|
||||||
|
if (not allow_fp16) and param.grad.dtype == torch.float16:
|
||||||
|
raise ValueError(
|
||||||
|
"Attempting to unscale FP16 gradients.")
|
||||||
|
if param.grad.is_sparse:
|
||||||
|
# is_coalesced() == False means the sparse grad has values with duplicate indices.
|
||||||
|
# coalesce() deduplicates indices and adds all values that have the same index.
|
||||||
|
# For scaled fp16 values, there's a good chance coalescing will cause overflow,
|
||||||
|
# so we should check the coalesced _values().
|
||||||
|
if param.grad.dtype is torch.float16:
|
||||||
|
param.grad = param.grad.coalesce()
|
||||||
|
to_unscale = param.grad._values()
|
||||||
|
else:
|
||||||
|
to_unscale = param.grad
|
||||||
|
|
||||||
|
# TODO: is there a way to split by device and dtype without appending in the inner loop?
|
||||||
|
per_device_and_dtype_grads[to_unscale.device][to_unscale.dtype].append(
|
||||||
|
to_unscale)
|
||||||
|
|
||||||
|
for device, per_dtype_grads in per_device_and_dtype_grads.items():
|
||||||
|
for grads in per_dtype_grads.values():
|
||||||
|
torch._amp_foreach_non_finite_check_and_unscale_(grads,
|
||||||
|
per_device_found_inf.get(
|
||||||
|
device),
|
||||||
|
per_device_inv_scale.get(device))
|
||||||
|
# For tensor parallel paramters it should be all-reduced over tensor parallel process group
|
||||||
|
if gpc.is_initialized(ParallelMode.TENSOR) and gpc.get_world_size(ParallelMode.TENSOR) > 1:
|
||||||
|
for tensor in per_device_found_inf._per_device_tensors.values():
|
||||||
|
dist.all_reduce(tensor, op=dist.ReduceOp.MAX,
|
||||||
|
group=gpc.get_group(ParallelMode.TENSOR))
|
||||||
|
return per_device_found_inf._per_device_tensors
|
||||||
|
|
||||||
|
def unscale_(self, optimizer):
|
||||||
|
"""
|
||||||
|
Divides ("unscales") the optimizer's gradient tensors by the scale factor.
|
||||||
|
|
||||||
|
:meth:`unscale_` is optional, serving cases where you need to
|
||||||
|
:ref:`modify or inspect gradients<working-with-unscaled-gradients>`
|
||||||
|
between the backward pass(es) and :meth:`step`.
|
||||||
|
If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`.
|
||||||
|
|
||||||
|
Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::
|
||||||
|
|
||||||
|
...
|
||||||
|
scaler.scale(loss).backward()
|
||||||
|
scaler.unscale_(optimizer)
|
||||||
|
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
|
||||||
|
scaler.step(optimizer)
|
||||||
|
scaler.update()
|
||||||
|
|
||||||
|
Args:
|
||||||
|
optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
:meth:`unscale_` does not incur a CPU-GPU sync.
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
:meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
|
||||||
|
and only after all gradients for that optimizer's assigned parameters have been accumulated.
|
||||||
|
Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
:meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
|
||||||
|
"""
|
||||||
|
if not self._enabled:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._check_scale_growth_tracker("unscale_")
|
||||||
|
|
||||||
|
optimizer_state = self._per_optimizer_states[id(optimizer)]
|
||||||
|
|
||||||
|
if optimizer_state["stage"] is OptState.UNSCALED:
|
||||||
|
raise RuntimeError(
|
||||||
|
"unscale_() has already been called on this optimizer since the last update().")
|
||||||
|
elif optimizer_state["stage"] is OptState.STEPPED:
|
||||||
|
raise RuntimeError("unscale_() is being called after step().")
|
||||||
|
|
||||||
|
# FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
|
||||||
|
assert self._scale is not None
|
||||||
|
inv_scale = self._scale.double().reciprocal().float()
|
||||||
|
found_inf = torch.full(
|
||||||
|
(1,), 0.0, dtype=torch.float32, device=self._scale.device)
|
||||||
|
|
||||||
|
optimizer_state["found_inf_per_device"] = self._unscale_grads_(
|
||||||
|
optimizer, inv_scale, found_inf, False)
|
||||||
|
optimizer_state["stage"] = OptState.UNSCALED
|
||||||
|
|
||||||
|
def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs):
|
||||||
|
retval = None
|
||||||
|
if not sum(v.item() for v in optimizer_state["found_inf_per_device"].values()):
|
||||||
|
retval = optimizer.step(*args, **kwargs)
|
||||||
|
return retval
|
||||||
|
|
||||||
|
def step(self, optimizer, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
:meth:`step` carries out the following two operations:
|
||||||
|
|
||||||
|
1. Internally invokes ``unscale_(optimizer)`` (unless :meth:`unscale_` was explicitly called for ``optimizer``
|
||||||
|
earlier in the iteration). As part of the :meth:`unscale_`, gradients are checked for infs/NaNs.
|
||||||
|
2. If no inf/NaN gradients are found, invokes ``optimizer.step()`` using the unscaled
|
||||||
|
gradients. Otherwise, ``optimizer.step()`` is skipped to avoid corrupting the params.
|
||||||
|
|
||||||
|
``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``.
|
||||||
|
|
||||||
|
Returns the return value of ``optimizer.step(*args, **kwargs)``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
optimizer (torch.optim.Optimizer): Optimizer that applies the gradients.
|
||||||
|
args: Any arguments.
|
||||||
|
kwargs: Any keyword arguments.
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
Closure use is not currently supported.
|
||||||
|
"""
|
||||||
|
if (not self._enabled):
|
||||||
|
return optimizer.step(*args, **kwargs)
|
||||||
|
|
||||||
|
if "closure" in kwargs:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Closure use is not currently supported if GradScaler is enabled.")
|
||||||
|
|
||||||
|
self._check_scale_growth_tracker("step")
|
||||||
|
|
||||||
|
optimizer_state = self._per_optimizer_states[id(optimizer)]
|
||||||
|
|
||||||
|
if optimizer_state["stage"] is OptState.STEPPED:
|
||||||
|
raise RuntimeError(
|
||||||
|
"step() has already been called since the last update().")
|
||||||
|
|
||||||
|
retval = None
|
||||||
|
|
||||||
|
if (hasattr(optimizer, "_step_supports_amp_scaling") and optimizer._step_supports_amp_scaling):
|
||||||
|
# This optimizer has customized scale-handling logic, so we can call optimizer.step() directly.
|
||||||
|
# The contract with custom optimizers is that their step() should accept an additional,
|
||||||
|
# optional grad_scaler kwarg. We append self to the kwargs so the custom optimizer has full information:
|
||||||
|
# it can query its own state, invoke unscale_ on itself, etc
|
||||||
|
retval = optimizer.step(*args, **dict(kwargs, grad_scaler=self))
|
||||||
|
optimizer_state["stage"] = OptState.STEPPED
|
||||||
|
return retval
|
||||||
|
|
||||||
|
if optimizer_state["stage"] is OptState.READY:
|
||||||
|
self.unscale_(optimizer)
|
||||||
|
|
||||||
|
assert len(optimizer_state["found_inf_per_device"]
|
||||||
|
) > 0, "No inf checks were recorded for this optimizer."
|
||||||
|
|
||||||
|
retval = self._maybe_opt_step(
|
||||||
|
optimizer, optimizer_state, *args, **kwargs)
|
||||||
|
|
||||||
|
optimizer_state["stage"] = OptState.STEPPED
|
||||||
|
|
||||||
|
return retval
|
||||||
|
|
||||||
|
def update(self, new_scale=None):
|
||||||
|
"""
|
||||||
|
Updates the scale factor.
|
||||||
|
|
||||||
|
If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
|
||||||
|
to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
|
||||||
|
the scale is multiplied by ``growth_factor`` to increase it.
|
||||||
|
|
||||||
|
Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
|
||||||
|
used directly, it's used to fill GradScaler's internal scale tensor. So if
|
||||||
|
``new_scale`` was a tensor, later in-place changes to that tensor will not further
|
||||||
|
affect the scale GradScaler uses internally.)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
new_scale (float or :class:`torch.cuda.FloatTensor`, optional, default=None): New scale factor.
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
:meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
|
||||||
|
been invoked for all optimizers used this iteration.
|
||||||
|
"""
|
||||||
|
if not self._enabled:
|
||||||
|
return
|
||||||
|
|
||||||
|
_scale, _growth_tracker = self._check_scale_growth_tracker("update")
|
||||||
|
|
||||||
|
if new_scale is not None:
|
||||||
|
# Accept a new user-defined scale.
|
||||||
|
if isinstance(new_scale, float):
|
||||||
|
self._scale.fill_(new_scale) # type: ignore[union-attr]
|
||||||
|
else:
|
||||||
|
reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False."
|
||||||
|
# type: ignore[attr-defined]
|
||||||
|
assert isinstance(new_scale, torch.cuda.FloatTensor), reason
|
||||||
|
assert new_scale.numel() == 1, reason
|
||||||
|
assert new_scale.requires_grad is False, reason
|
||||||
|
self._scale.copy_(new_scale) # type: ignore[union-attr]
|
||||||
|
else:
|
||||||
|
# Consume shared inf/nan data collected from optimizers to update the scale.
|
||||||
|
# If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
|
||||||
|
found_infs = [found_inf.to(device=_scale.device, non_blocking=True)
|
||||||
|
for state in self._per_optimizer_states.values()
|
||||||
|
for found_inf in state["found_inf_per_device"].values()]
|
||||||
|
|
||||||
|
assert len(
|
||||||
|
found_infs) > 0, "No inf checks were recorded prior to update."
|
||||||
|
|
||||||
|
found_inf_combined = found_infs[0]
|
||||||
|
if len(found_infs) > 1:
|
||||||
|
for i in range(1, len(found_infs)):
|
||||||
|
found_inf_combined += found_infs[i]
|
||||||
|
|
||||||
|
torch._amp_update_scale_(_scale,
|
||||||
|
_growth_tracker,
|
||||||
|
found_inf_combined,
|
||||||
|
self._growth_factor,
|
||||||
|
self._backoff_factor,
|
||||||
|
self._growth_interval)
|
||||||
|
|
||||||
|
# To prepare for next iteration, clear the data collected from optimizers this iteration.
|
||||||
|
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
|
||||||
|
|
||||||
|
def _get_scale_async(self):
|
||||||
|
return self._scale
|
||||||
|
|
||||||
|
def get_scale(self):
|
||||||
|
"""
|
||||||
|
Returns a Python float containing the current scale, or 1.0 if scaling is disabled.
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
:meth:`get_scale` incurs a CPU-GPU sync.
|
||||||
|
"""
|
||||||
|
if self._enabled:
|
||||||
|
return self._init_scale if self._scale is None else self._get_scale_async().item()
|
||||||
|
else:
|
||||||
|
return 1.0
|
||||||
|
|
||||||
|
def get_growth_factor(self):
|
||||||
|
r"""
|
||||||
|
Returns a Python float containing the scale growth factor.
|
||||||
|
"""
|
||||||
|
return self._growth_factor
|
||||||
|
|
||||||
|
def set_growth_factor(self, new_factor):
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
new_scale (float): Value to use as the new scale growth factor.
|
||||||
|
"""
|
||||||
|
self._growth_factor = new_factor
|
||||||
|
|
||||||
|
def get_backoff_factor(self):
|
||||||
|
r"""
|
||||||
|
Returns a Python float containing the scale backoff factor.
|
||||||
|
"""
|
||||||
|
return self._backoff_factor
|
||||||
|
|
||||||
|
def set_backoff_factor(self, new_factor):
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
new_scale (float): Value to use as the new scale backoff factor.
|
||||||
|
"""
|
||||||
|
self._backoff_factor = new_factor
|
||||||
|
|
||||||
|
def get_growth_interval(self):
|
||||||
|
r"""
|
||||||
|
Returns a Python int containing the growth interval.
|
||||||
|
"""
|
||||||
|
return self._growth_interval
|
||||||
|
|
||||||
|
def set_growth_interval(self, new_interval):
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
new_interval (int): Value to use as the new growth interval.
|
||||||
|
"""
|
||||||
|
self._growth_interval = new_interval
|
||||||
|
|
||||||
|
def _get_growth_tracker(self):
|
||||||
|
if self._enabled:
|
||||||
|
return self._init_growth_tracker if self._growth_tracker is None else self._growth_tracker.item()
|
||||||
|
else:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def is_enabled(self):
|
||||||
|
r"""
|
||||||
|
Returns a bool indicating whether this instance is enabled.
|
||||||
|
"""
|
||||||
|
return self._enabled
|
||||||
|
|
||||||
|
def state_dict(self):
|
||||||
|
r"""
|
||||||
|
Returns the state of the scaler as a :class:`dict`. It contains five entries:
|
||||||
|
|
||||||
|
* ``"scale"`` - a Python float containing the current scale
|
||||||
|
* ``"growth_factor"`` - a Python float containing the current growth factor
|
||||||
|
* ``"backoff_factor"`` - a Python float containing the current backoff factor
|
||||||
|
* ``"growth_interval"`` - a Python int containing the current growth interval
|
||||||
|
* ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps.
|
||||||
|
|
||||||
|
If this instance is not enabled, returns an empty dict.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict`
|
||||||
|
should be called after :meth:`update`.
|
||||||
|
"""
|
||||||
|
return {"scale": self.get_scale(),
|
||||||
|
"growth_factor": self._growth_factor,
|
||||||
|
"backoff_factor": self._backoff_factor,
|
||||||
|
"growth_interval": self._growth_interval,
|
||||||
|
"_growth_tracker": self._get_growth_tracker()} if self._enabled else {}
|
||||||
|
|
||||||
|
def load_state_dict(self, state_dict):
|
||||||
|
r"""
|
||||||
|
Loads the scaler state. If this instance is disabled, :meth:`load_state_dict` is a no-op.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state_dict(dict): scaler state. Should be an object returned from a call to :meth:`state_dict`.
|
||||||
|
"""
|
||||||
|
if not self._enabled:
|
||||||
|
return
|
||||||
|
|
||||||
|
if len(state_dict) == 0:
|
||||||
|
raise RuntimeError("The source state dict is empty, possibly because it was saved "
|
||||||
|
"from a disabled instance of GradScaler.")
|
||||||
|
|
||||||
|
self._init_scale = state_dict["scale"]
|
||||||
|
if self._scale is not None:
|
||||||
|
self._scale.fill_(state_dict["scale"])
|
||||||
|
self._growth_factor = state_dict["growth_factor"]
|
||||||
|
self._backoff_factor = state_dict["backoff_factor"]
|
||||||
|
self._growth_interval = state_dict["growth_interval"]
|
||||||
|
self._init_growth_tracker = state_dict["_growth_tracker"]
|
||||||
|
if self._growth_tracker is not None:
|
||||||
|
self._growth_tracker.fill_(state_dict["_growth_tracker"])
|
||||||
|
|
||||||
|
def __getstate__(self):
|
||||||
|
state = self.__dict__.copy()
|
||||||
|
if self._enabled:
|
||||||
|
assert len(self._per_optimizer_states) == 0, "A GradScaler instance may only be pickled at the beginning "\
|
||||||
|
"of an iteration, or at the end after scaler.update()."
|
||||||
|
# Pickling _scale and _growth_tracker Tensors directly triggers
|
||||||
|
# "warnings.warn("pickle support for Storage will be removed in 1.5..."
|
||||||
|
# so instead, we set the unpickled instance up to reinitialize them lazily.
|
||||||
|
state['_init_scale'] = self.get_scale()
|
||||||
|
state['_init_growth_tracker'] = self._get_growth_tracker()
|
||||||
|
state['_scale'] = None
|
||||||
|
state['_growth_tracker'] = None
|
||||||
|
return state
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
self.__dict__.update(state)
|
||||||
|
|
||||||
|
def _check_inf_per_device(self, optimizer):
|
||||||
|
_scale, _ = self._check_scale_growth_tracker("_check_inf_per_device")
|
||||||
|
|
||||||
|
dummy_inv_scale = torch.full(
|
||||||
|
(1,), 1.0, dtype=torch.float32, device=_scale.device)
|
||||||
|
found_inf = torch.full(
|
||||||
|
(1,), 0.0, dtype=torch.float32, device=_scale.device)
|
||||||
|
|
||||||
|
self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \
|
||||||
|
self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True)
|
||||||
|
|
||||||
|
return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]
|
||||||
|
|
||||||
|
def _found_inf_per_device(self, optimizer):
|
||||||
|
return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]
|
|
@ -5,125 +5,85 @@ from abc import ABC, abstractmethod
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from colossalai.core import global_context as gpc
|
||||||
from colossalai.logging import get_global_dist_logger
|
from colossalai.logging import get_global_dist_logger
|
||||||
from colossalai.utils import get_current_device
|
from colossalai.utils import get_current_device
|
||||||
|
|
||||||
|
|
||||||
class BaseSchedule(ABC):
|
class BaseSchedule(ABC):
|
||||||
"""A basic helper class to control the process of training or evaluation.
|
"""A basic helper class to control the process of training or evaluation.
|
||||||
|
It mainly composes of forward_backward_step for gradient backward and
|
||||||
|
optimizer_step for parameters update.
|
||||||
|
For the convenience to enable FP16, we aggreate all codes that contain the
|
||||||
|
control of FP16 in class schedule.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.initialized = False
|
|
||||||
self.logger = get_global_dist_logger()
|
self.logger = get_global_dist_logger()
|
||||||
|
|
||||||
@property
|
@staticmethod
|
||||||
@abstractmethod
|
def _move_tensor(element):
|
||||||
def num_steps(self):
|
if torch.is_tensor(element):
|
||||||
"""The number of batches in training or evaluation.
|
if not element.is_cuda:
|
||||||
"""
|
return element.to(get_current_device()).detach()
|
||||||
pass
|
return element
|
||||||
|
|
||||||
def initialize(self,
|
|
||||||
dataloader=None,
|
|
||||||
model=None,
|
|
||||||
criterion=None,
|
|
||||||
optimizer=None,
|
|
||||||
lr_scheduler=None):
|
|
||||||
"""Initializes the schedule and set parameters before running.
|
|
||||||
|
|
||||||
:param dataloader: DataLoader in training or evaluation
|
|
||||||
:param model: The neural network model
|
|
||||||
:param criterion: Criterion for calculating loss
|
|
||||||
:param optimizer: Optimizer for updating the parameters
|
|
||||||
:param lr_scheduler: Learning rate scheduler in the process
|
|
||||||
"""
|
|
||||||
self.dataloader = dataloader
|
|
||||||
assert model is not None, "Schedule requires a model"
|
|
||||||
self.model = model
|
|
||||||
assert criterion is not None, "Schedule requires a criterion"
|
|
||||||
self.criterion = criterion
|
|
||||||
assert optimizer is not None, "Schedule requires an optimizer"
|
|
||||||
self.optimizer = optimizer
|
|
||||||
self.lr_scheduler = lr_scheduler
|
|
||||||
self.initialized = True
|
|
||||||
|
|
||||||
def check_initialized(self):
|
|
||||||
"""Checks whether the schedule is initialized.
|
|
||||||
"""
|
|
||||||
assert self.initialized, \
|
|
||||||
'Schedule is not initialized. Call schedule.initialize(...) before using it.'
|
|
||||||
|
|
||||||
def load_batch(self):
|
|
||||||
"""Loads a batch of dataset. It returns the data and labels which are
|
|
||||||
already in the same GPU as where the model's.
|
|
||||||
|
|
||||||
:return: (data, label)
|
|
||||||
:rtype: (Tensor, Tensor)
|
|
||||||
"""
|
|
||||||
self.check_initialized()
|
|
||||||
if self.data_iter is None:
|
|
||||||
raise RuntimeError('Dataloader is not defined.')
|
|
||||||
data, label = next(self.data_iter)
|
|
||||||
return self._move_to_device(data), self._move_to_device(label)
|
|
||||||
|
|
||||||
def _move_to_device(self, data):
|
def _move_to_device(self, data):
|
||||||
if isinstance(data, (
|
if isinstance(data, (tuple, list)):
|
||||||
tuple,
|
data = tuple([self._move_tensor(d) for d in data])
|
||||||
list,
|
|
||||||
)):
|
|
||||||
data = tuple([
|
|
||||||
d.to(get_current_device()).detach() for d in data
|
|
||||||
if torch.is_tensor(d)
|
|
||||||
])
|
|
||||||
elif torch.is_tensor(data):
|
elif torch.is_tensor(data):
|
||||||
data = data.to(get_current_device()).detach()
|
data = data.to(get_current_device()).detach()
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def train(self, dataloader=None, mode=True):
|
def load_batch(self, data_iter):
|
||||||
"""Sets the dataloader to be used and turn the model to
|
"""Loads a batch from data iterator. It returns the data and labels which are
|
||||||
training or evaluation mode.
|
already in the same GPU as where the model's.
|
||||||
|
|
||||||
:param dataloader: Dataloader to be used
|
:return: (data, label)
|
||||||
:param mode: If True, the model will set as training mode. Otherwise, evaluation mode.
|
:rtype: (Tensor, Tensor)
|
||||||
"""
|
"""
|
||||||
self.check_initialized()
|
if data_iter is None:
|
||||||
if mode:
|
raise RuntimeError('Dataloader is not defined.')
|
||||||
self.model.train()
|
data, label = next(data_iter)
|
||||||
else:
|
return self._move_to_device(data), self._move_to_device(label)
|
||||||
self.model.eval()
|
|
||||||
if dataloader is not None:
|
|
||||||
self.dataloader = dataloader
|
|
||||||
self.data_iter = iter(dataloader)
|
|
||||||
|
|
||||||
def zero_grad(self, forward_only=False):
|
def initialize(self, model, optimizer):
|
||||||
"""Cleans gradients with the optimizer.
|
"""Initializes the model and the optimizer before training.
|
||||||
"""
|
This is often used in FP16 training.
|
||||||
if not forward_only:
|
|
||||||
self.check_initialized()
|
|
||||||
self.optimizer.zero_grad()
|
|
||||||
|
|
||||||
def get_lr(self):
|
:param model: The neural network model
|
||||||
"""Returns the current learning rate.
|
:param optimizer: Optimizer for updating the parameters
|
||||||
"""
|
"""
|
||||||
if self.lr_scheduler is not None:
|
return model, optimizer
|
||||||
return self.lr_scheduler.get_lr()[0]
|
|
||||||
else:
|
|
||||||
return self.optimizer.param_groups[0]['lr']
|
|
||||||
|
|
||||||
def step(self):
|
|
||||||
"""Updates the parameters and learning rate with the optimizer.
|
|
||||||
"""
|
|
||||||
self.check_initialized()
|
|
||||||
self.optimizer.step()
|
|
||||||
# update lr scheduler
|
|
||||||
if self.lr_scheduler is not None:
|
|
||||||
self.lr_scheduler.step()
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def forward_backward_step(self, forward_only=False, return_loss=True):
|
def forward_backward_step(self,
|
||||||
|
data_iter,
|
||||||
|
model,
|
||||||
|
criterion,
|
||||||
|
optimizer=None,
|
||||||
|
forward_only=False,
|
||||||
|
grad_accum_size: int = 1,
|
||||||
|
return_loss=True):
|
||||||
"""The process function over a batch of dataset for training or evaluation.
|
"""The process function over a batch of dataset for training or evaluation.
|
||||||
|
|
||||||
:param forward_only: If True, the process won't include backward.
|
:param data_iter: Data iterator of the dataset
|
||||||
:param return_loss: If False, the loss won't be returned.
|
:param model: Model used in training or evaluation
|
||||||
|
:param optimizer: Optimizer used in training or evaluation
|
||||||
|
:param criterion: Loss function
|
||||||
|
:param forward_only: If True, the process won't include backward
|
||||||
|
:param grad_accum_size: Steps of gradient accumulation
|
||||||
|
:param return_loss: If False, the loss won't be returned
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def optimizer_step(self, model, optimizer, grad_clipping: float = 0.0):
|
||||||
|
"""Updates the parameters with the optimizer.
|
||||||
|
|
||||||
|
:param model: The neural network model
|
||||||
|
:param optimizer: Optimizer for updating the parameters
|
||||||
|
:param grad_clipping: The norm of gradient clipping
|
||||||
|
:type grad_clipping: float, optional
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -4,19 +4,24 @@
|
||||||
try:
|
try:
|
||||||
import apex.amp as apex_amp
|
import apex.amp as apex_amp
|
||||||
except:
|
except:
|
||||||
print('apex is required for mixed precision training')
|
pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import torch.cuda.amp as torch_amp
|
import torch.cuda.amp as torch_amp
|
||||||
except:
|
except:
|
||||||
print('PyTorch amp is not supported with the current PyTorch version')
|
pass
|
||||||
|
|
||||||
|
from typing import Iterable
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.optim import Optimizer
|
||||||
|
|
||||||
from colossalai.context import ParallelMode
|
|
||||||
from colossalai.core import global_context as gpc
|
|
||||||
from colossalai.engine.amp_type import AMP_TYPE
|
|
||||||
from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
|
from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
|
||||||
ZeroRedundancyOptimizer_Level_3)
|
ZeroRedundancyOptimizer_Level_3)
|
||||||
from ._utils import convert_to_fp16
|
from colossalai.nn.optimizer._utils import clip_grad_norm_fp32
|
||||||
from ._base_schedule import BaseSchedule
|
from ._base_schedule import BaseSchedule
|
||||||
|
from ._utils import convert_to_fp16, convert_to_fp32
|
||||||
|
from ..amp import AMP_TYPE, GradScaler
|
||||||
|
|
||||||
|
|
||||||
class NoPipelineSchedule(BaseSchedule):
|
class NoPipelineSchedule(BaseSchedule):
|
||||||
|
@ -30,6 +35,7 @@ class NoPipelineSchedule(BaseSchedule):
|
||||||
:type amp_type: AMP_TYPE
|
:type amp_type: AMP_TYPE
|
||||||
:type amp_config: dict
|
:type amp_config: dict
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
amp_type: AMP_TYPE = None,
|
amp_type: AMP_TYPE = None,
|
||||||
|
@ -41,12 +47,6 @@ class NoPipelineSchedule(BaseSchedule):
|
||||||
assert amp_type is None or isinstance(amp_type, AMP_TYPE), \
|
assert amp_type is None or isinstance(amp_type, AMP_TYPE), \
|
||||||
'unrecognised value for argument fp16, it can only be None, torch or apex'
|
'unrecognised value for argument fp16, it can only be None, torch or apex'
|
||||||
|
|
||||||
# LSG: check compatibility
|
|
||||||
# LSG: torch.cuda.amp and apex.amp cannot be used for tensor parallel
|
|
||||||
if gpc.is_initialized(ParallelMode.TENSOR) and gpc.get_world_size(
|
|
||||||
ParallelMode.TENSOR) > 1:
|
|
||||||
assert amp_type != AMP_TYPE.TORCH and amp_type != AMP_TYPE.APEX, \
|
|
||||||
'You can only AMP_TYPE.PARALLEL for tensor parallel training'
|
|
||||||
self.use_zero_level_2_3 = False
|
self.use_zero_level_2_3 = False
|
||||||
|
|
||||||
if amp_type is not None:
|
if amp_type is not None:
|
||||||
|
@ -79,107 +79,110 @@ class NoPipelineSchedule(BaseSchedule):
|
||||||
self.fp16 = False
|
self.fp16 = False
|
||||||
self.amp_type = None
|
self.amp_type = None
|
||||||
|
|
||||||
@property
|
def initialize(self, model: nn.Module, optimizer: Optimizer):
|
||||||
def num_steps(self):
|
if isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2,
|
||||||
return len(self.dataloader)
|
ZeroRedundancyOptimizer_Level_3)):
|
||||||
|
|
||||||
def initialize(self,
|
|
||||||
dataloader,
|
|
||||||
model,
|
|
||||||
criterion,
|
|
||||||
optimizer,
|
|
||||||
lr_scheduler=None):
|
|
||||||
super().initialize(dataloader,
|
|
||||||
model,
|
|
||||||
criterion,
|
|
||||||
optimizer,
|
|
||||||
lr_scheduler=lr_scheduler)
|
|
||||||
if isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2,
|
|
||||||
ZeroRedundancyOptimizer_Level_3)):
|
|
||||||
self.use_zero_level_2_3 = True
|
self.use_zero_level_2_3 = True
|
||||||
assert self.amp_type != AMP_TYPE.PARALLEL, 'ZeRO Level 2 and 3 are mutually exclusive with AMP_TYPE.PARALLEL'
|
assert self.amp_type != AMP_TYPE.PARALLEL, \
|
||||||
|
'ZeRO Level 2 and 3 are mutually exclusive with AMP_TYPE.PARALLEL'
|
||||||
|
|
||||||
if self.fp16:
|
if self.fp16:
|
||||||
if self.amp_type == AMP_TYPE.TORCH:
|
if self.amp_type == AMP_TYPE.TORCH:
|
||||||
self._torch_amp_scaler = torch_amp.GradScaler(**self.amp_cfg)
|
self._torch_amp_scaler = GradScaler(**self.amp_cfg)
|
||||||
elif self.amp_type == AMP_TYPE.APEX:
|
elif self.amp_type == AMP_TYPE.APEX:
|
||||||
self.model, self.optimizer = apex_amp.initialize(
|
model, optimizer = apex_amp.initialize(model, optimizer, **self.amp_cfg)
|
||||||
self.model, self.optimizer, **self.amp_cfg)
|
|
||||||
|
|
||||||
def forward_backward_step(self, forward_only=False, return_loss=True):
|
return model, optimizer
|
||||||
|
|
||||||
|
def forward_backward_step(self,
|
||||||
|
data_iter: Iterable,
|
||||||
|
model: nn.Module,
|
||||||
|
criterion: nn.modules.loss._Loss,
|
||||||
|
optimizer: Optimizer = None,
|
||||||
|
forward_only: bool = False,
|
||||||
|
grad_accum_size: int = 1,
|
||||||
|
return_loss: bool = True):
|
||||||
"""The process function that loads loads a batch of dataset and feeds it to the model.
|
"""The process function that loads loads a batch of dataset and feeds it to the model.
|
||||||
The returned labels and loss will None if :attr:`return_loss` is False.
|
The returned labels and loss will None if :attr:`return_loss` is False.
|
||||||
|
|
||||||
|
:param data_iter: Data iterator of the dataloader, e.g. iter(dataloader)
|
||||||
|
:param model: Model for training and inference
|
||||||
|
:param criterion: Loss function for training
|
||||||
|
:param optimizer: Optimizer used for training
|
||||||
|
:param forward_only: If True, the model is run for the forward pass, else back propagation will be executed
|
||||||
|
:param grad_accum_size: The number of iterations for gradient accumulation
|
||||||
|
:param return_loss: Loss will be returned if True
|
||||||
|
:type data_iter: Iterator
|
||||||
|
:type model: torch.nn.Module
|
||||||
|
:type criterion: torch.nn.modules.loss._Loss
|
||||||
|
:type optimizer: torch.optim.Optimizer
|
||||||
|
:type forward_only: bool, optional
|
||||||
|
:type grad_accum_size: int
|
||||||
|
:type return_loss: bool, optional
|
||||||
:return: (output, label, loss)
|
:return: (output, label, loss)
|
||||||
"""
|
"""
|
||||||
assert forward_only or return_loss, \
|
assert forward_only or return_loss, \
|
||||||
'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
|
'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
|
||||||
|
|
||||||
data, label = self.load_batch()
|
data, label = self.load_batch(data_iter)
|
||||||
loss = None
|
loss = None
|
||||||
|
|
||||||
# LSG: leave for debug, make sure dataloader is deterministic
|
|
||||||
# if forward_only:
|
|
||||||
# img = data[0]
|
|
||||||
# rank = gpc.get_local_rank(ParallelMode.DATA)
|
|
||||||
# world_size = gpc.get_world_size(ParallelMode.DATA)
|
|
||||||
# group = gpc.get_group(ParallelMode.DATA)
|
|
||||||
# input_list = [img.clone() for _ in range(world_size)]
|
|
||||||
# output_list = [torch.empty_like(img) for _ in range(world_size)]
|
|
||||||
# output_list[rank] = img.clone()
|
|
||||||
# dist.all_to_all(output_tensor_list=output_list, input_tensor_list=input_list, group=group)
|
|
||||||
# assert torch.equal(output_list[0], output_list[1]) # and torch.equal(output_list[1], output_list[2])
|
|
||||||
|
|
||||||
# forward
|
# forward
|
||||||
if self.fp16 and self.amp_type == AMP_TYPE.TORCH:
|
if self.fp16 and self.amp_type == AMP_TYPE.TORCH:
|
||||||
with torch_amp.autocast():
|
with torch_amp.autocast():
|
||||||
output = self.model(*data)
|
output = model(*data)
|
||||||
if not isinstance(output, (tuple, list)):
|
if not isinstance(output, (tuple, list)):
|
||||||
output = (output,)
|
output = (output,)
|
||||||
if return_loss:
|
if return_loss:
|
||||||
loss = self.criterion(*output, *label)
|
loss = criterion(*output, *label)
|
||||||
else:
|
else:
|
||||||
if self.use_zero_level_2_3 or self.amp_type == AMP_TYPE.PARALLEL:
|
if self.use_zero_level_2_3 or self.amp_type == AMP_TYPE.PARALLEL:
|
||||||
data = convert_to_fp16(data)
|
data = convert_to_fp16(data)
|
||||||
|
|
||||||
output = self.model(*data)
|
output = model(*data)
|
||||||
|
|
||||||
|
if self.use_zero_level_2_3 or self.amp_type == AMP_TYPE.PARALLEL:
|
||||||
|
output = convert_to_fp32(output)
|
||||||
|
|
||||||
if not isinstance(output, (tuple, list)):
|
if not isinstance(output, (tuple, list)):
|
||||||
output = (output,)
|
output = (output,)
|
||||||
if return_loss:
|
if return_loss:
|
||||||
loss = self.criterion(*output, *label)
|
loss = criterion(*output, *label)
|
||||||
|
|
||||||
|
loss /= grad_accum_size
|
||||||
|
|
||||||
if not forward_only:
|
if not forward_only:
|
||||||
# backward
|
# backward
|
||||||
if self.use_zero_level_2_3:
|
if self.use_zero_level_2_3:
|
||||||
self.optimizer.backward(loss)
|
optimizer.backward(loss)
|
||||||
elif self.fp16:
|
elif self.fp16:
|
||||||
if self.amp_type == AMP_TYPE.APEX:
|
if self.amp_type == AMP_TYPE.APEX:
|
||||||
with apex_amp.scale_loss(loss,
|
with apex_amp.scale_loss(loss, optimizer) as scaled_loss:
|
||||||
self.optimizer) as scaled_loss:
|
|
||||||
scaled_loss.backward()
|
scaled_loss.backward()
|
||||||
elif self.amp_type == AMP_TYPE.TORCH:
|
elif self.amp_type == AMP_TYPE.TORCH:
|
||||||
self._torch_amp_scaler.scale(loss).backward()
|
self._torch_amp_scaler.scale(loss).backward()
|
||||||
elif self.amp_type == AMP_TYPE.PARALLEL:
|
elif self.amp_type == AMP_TYPE.PARALLEL:
|
||||||
loss = self.optimizer.scale_loss(loss)
|
loss = optimizer.scale_loss(loss)
|
||||||
loss.backward()
|
loss.backward()
|
||||||
# scale back to display the original value in logs
|
# scale back to display the original value in logs
|
||||||
loss.div_(self.optimizer.grad_scaler.scale)
|
loss.div_(optimizer.grad_scaler.scale)
|
||||||
else:
|
else:
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
if return_loss:
|
if return_loss:
|
||||||
return output, label, loss
|
return output, label, loss * grad_accum_size
|
||||||
else:
|
else:
|
||||||
return output, None, None
|
return output, None, None
|
||||||
|
|
||||||
def step(self):
|
def optimizer_step(self, model: nn.Module, optimizer: Optimizer, grad_clipping: float = 0.0):
|
||||||
# step optimizer
|
# step optimizer
|
||||||
if self.fp16 and self.amp_type == AMP_TYPE.TORCH:
|
if self.fp16 and self.amp_type == AMP_TYPE.TORCH:
|
||||||
self._torch_amp_scaler.step(self.optimizer)
|
if grad_clipping > 0.0:
|
||||||
|
self._torch_amp_scaler.unscale_(optimizer)
|
||||||
|
clip_grad_norm_fp32(model.parameters(), grad_clipping)
|
||||||
|
self._torch_amp_scaler.step(optimizer)
|
||||||
self._torch_amp_scaler.update()
|
self._torch_amp_scaler.update()
|
||||||
else:
|
else:
|
||||||
self.optimizer.step()
|
if not self.fp16 and not self.use_zero_level_2_3 and grad_clipping > 0.0:
|
||||||
|
clip_grad_norm_fp32(model.parameters(), grad_clipping)
|
||||||
# update lr scheduler
|
optimizer.step()
|
||||||
if self.lr_scheduler is not None:
|
|
||||||
self.lr_scheduler.step()
|
|
||||||
|
|
|
@ -15,7 +15,7 @@ from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
|
||||||
from colossalai.utils import get_current_device
|
from colossalai.utils import get_current_device
|
||||||
from ._base_schedule import BaseSchedule
|
from ._base_schedule import BaseSchedule
|
||||||
from ._utils import convert_to_fp16
|
from ._utils import convert_to_fp16
|
||||||
from ..amp_type import AMP_TYPE
|
from ..amp import AMP_TYPE
|
||||||
|
|
||||||
|
|
||||||
def squeeze(x: Union[Tensor, tuple, list]):
|
def squeeze(x: Union[Tensor, tuple, list]):
|
||||||
|
@ -93,12 +93,11 @@ class PipelineSchedule(BaseSchedule):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Pipeline schedule just puts data in memory
|
# Pipeline schedule just puts data in memory
|
||||||
def load_batch(self):
|
def load_batch(self, data_iter):
|
||||||
self.check_initialized()
|
if data_iter is None:
|
||||||
if self.data_iter is None:
|
|
||||||
raise RuntimeError('Dataloader is not defined.')
|
raise RuntimeError('Dataloader is not defined.')
|
||||||
self.batch_pos = 0
|
self.batch_pos = 0
|
||||||
data, label = next(self.data_iter)
|
data, label = next(data_iter)
|
||||||
self.batch_data, self.batch_label = \
|
self.batch_data, self.batch_label = \
|
||||||
self._move_to_device(data), self._move_to_device(label)
|
self._move_to_device(data), self._move_to_device(label)
|
||||||
batch_size = self.batch_data.shape[0]
|
batch_size = self.batch_data.shape[0]
|
||||||
|
@ -117,23 +116,8 @@ class PipelineSchedule(BaseSchedule):
|
||||||
self.batch_pos += self.microbatch_size
|
self.batch_pos += self.microbatch_size
|
||||||
return (data,), (label,)
|
return (data,), (label,)
|
||||||
|
|
||||||
@property
|
def initialize(self, model, optimizer):
|
||||||
def num_steps(self):
|
if isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)):
|
||||||
return len(self.dataloader)
|
|
||||||
|
|
||||||
def initialize(self,
|
|
||||||
dataloader,
|
|
||||||
model,
|
|
||||||
criterion,
|
|
||||||
optimizer,
|
|
||||||
lr_scheduler=None):
|
|
||||||
super().initialize(dataloader,
|
|
||||||
model,
|
|
||||||
criterion,
|
|
||||||
optimizer,
|
|
||||||
lr_scheduler=lr_scheduler)
|
|
||||||
if isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2,
|
|
||||||
ZeroRedundancyOptimizer_Level_3)):
|
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
"Pipeline schedule is currently not compatible with ZeRO Level 2 and Level 3"
|
"Pipeline schedule is currently not compatible with ZeRO Level 2 and Level 3"
|
||||||
)
|
)
|
||||||
|
@ -145,7 +129,8 @@ class PipelineSchedule(BaseSchedule):
|
||||||
'default tensor dtype is set to torch.half for fp16 training',
|
'default tensor dtype is set to torch.half for fp16 training',
|
||||||
ranks=[0])
|
ranks=[0])
|
||||||
|
|
||||||
def forward_step(self, input_tensor, return_tensors, return_loss=True):
|
def forward_step(self, model, criterion, input_tensor, return_tensors,
|
||||||
|
grad_accum_size, return_loss=True):
|
||||||
"""Forward step for passed-in model. If it is the first stage, the input tensor
|
"""Forward step for passed-in model. If it is the first stage, the input tensor
|
||||||
is obtained from data_iterator, otherwise the passed-in input_tensor is used.
|
is obtained from data_iterator, otherwise the passed-in input_tensor is used.
|
||||||
Returns output tensor. This is a helper function and can be ignored by users.
|
Returns output tensor. This is a helper function and can be ignored by users.
|
||||||
|
@ -156,14 +141,14 @@ class PipelineSchedule(BaseSchedule):
|
||||||
if self.amp_type == AMP_TYPE.PARALLEL:
|
if self.amp_type == AMP_TYPE.PARALLEL:
|
||||||
input_tensor = convert_to_fp16(input_tensor)
|
input_tensor = convert_to_fp16(input_tensor)
|
||||||
input_tensor = squeeze(input_tensor)
|
input_tensor = squeeze(input_tensor)
|
||||||
output_tensor = self.model(input_tensor)
|
output_tensor = model(input_tensor)
|
||||||
output_tensor = squeeze(output_tensor)
|
output_tensor = squeeze(output_tensor)
|
||||||
|
|
||||||
if gpc.is_last_rank(ParallelMode.PIPELINE):
|
if gpc.is_last_rank(ParallelMode.PIPELINE):
|
||||||
if return_loss:
|
if return_loss:
|
||||||
input_tensor, label = self.load_micro_batch()
|
input_tensor, label = self.load_micro_batch()
|
||||||
loss_reduced = self.criterion(output_tensor, *
|
loss_reduced = criterion(output_tensor, *label) \
|
||||||
label) / self.num_microbatches
|
/ (self.num_microbatches * grad_accum_size)
|
||||||
return_tensors.append(
|
return_tensors.append(
|
||||||
tuple((output_tensor, label[0], loss_reduced)))
|
tuple((output_tensor, label[0], loss_reduced)))
|
||||||
return loss_reduced
|
return loss_reduced
|
||||||
|
@ -174,7 +159,7 @@ class PipelineSchedule(BaseSchedule):
|
||||||
else:
|
else:
|
||||||
return output_tensor
|
return output_tensor
|
||||||
|
|
||||||
def backward_step(self, input_tensor, output_tensor, output_tensor_grad):
|
def backward_step(self, optimizer, input_tensor, output_tensor, output_tensor_grad):
|
||||||
"""Backward step through the passed-in output tensor. If it is the last stage, the
|
"""Backward step through the passed-in output tensor. If it is the last stage, the
|
||||||
output_tensor_grad is None, otherwise it is the gradients with respect to stage's output tensor.
|
output_tensor_grad is None, otherwise it is the gradients with respect to stage's output tensor.
|
||||||
Returns the gradients with respect to the input tensor (None if first stage).
|
Returns the gradients with respect to the input tensor (None if first stage).
|
||||||
|
@ -187,7 +172,7 @@ class PipelineSchedule(BaseSchedule):
|
||||||
|
|
||||||
# Backward pass.
|
# Backward pass.
|
||||||
if output_tensor_grad is None and self.amp_type == AMP_TYPE.PARALLEL:
|
if output_tensor_grad is None and self.amp_type == AMP_TYPE.PARALLEL:
|
||||||
output_tensor = self.optimizer.scale_loss(output_tensor)
|
output_tensor = optimizer.scale_loss(output_tensor)
|
||||||
torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
|
torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
|
||||||
|
|
||||||
# Collect the grad of the input_tensor.
|
# Collect the grad of the input_tensor.
|
||||||
|
@ -197,17 +182,24 @@ class PipelineSchedule(BaseSchedule):
|
||||||
|
|
||||||
return input_tensor_grad
|
return input_tensor_grad
|
||||||
|
|
||||||
def forward_backward_step(self, forward_only=True, return_loss=True):
|
def forward_backward_step(self,
|
||||||
|
data_iter,
|
||||||
|
model,
|
||||||
|
criterion,
|
||||||
|
optimizer=None,
|
||||||
|
forward_only=False,
|
||||||
|
grad_accum_size: int = 1,
|
||||||
|
return_loss=True):
|
||||||
"""Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
|
"""Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
|
||||||
Returns a tuple with losses if the last stage, an empty tuple otherwise.
|
Returns a tuple with losses if the last stage, an empty tuple otherwise.
|
||||||
|
|
||||||
:return: (output, label, loss)
|
:return: (output, label, loss)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
assert forward_only or return_loss, \
|
assert forward_only or return_loss, \
|
||||||
'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
|
'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
|
||||||
|
|
||||||
self.load_batch()
|
self.load_batch(data_iter)
|
||||||
num_warmup_microbatches = \
|
num_warmup_microbatches = \
|
||||||
(gpc.get_world_size(ParallelMode.PIPELINE) -
|
(gpc.get_world_size(ParallelMode.PIPELINE) -
|
||||||
gpc.get_local_rank(ParallelMode.PIPELINE) - 1)
|
gpc.get_local_rank(ParallelMode.PIPELINE) - 1)
|
||||||
|
@ -233,9 +225,11 @@ class PipelineSchedule(BaseSchedule):
|
||||||
if not gpc.is_first_rank(ParallelMode.PIPELINE):
|
if not gpc.is_first_rank(ParallelMode.PIPELINE):
|
||||||
ft_shape = recv_tensor_meta(ft_shape)
|
ft_shape = recv_tensor_meta(ft_shape)
|
||||||
input_tensor = recv_forward(ft_shape)
|
input_tensor = recv_forward(ft_shape)
|
||||||
output_tensor = self.forward_step(input_tensor,
|
output_tensor = self.forward_step(
|
||||||
return_tensors,
|
model, criterion,
|
||||||
return_loss=return_loss)
|
input_tensor, return_tensors,
|
||||||
|
grad_accum_size, return_loss=return_loss
|
||||||
|
)
|
||||||
if not gpc.is_last_rank(ParallelMode.PIPELINE):
|
if not gpc.is_last_rank(ParallelMode.PIPELINE):
|
||||||
bt_shape = output_tensor.shape
|
bt_shape = output_tensor.shape
|
||||||
fs_checker = send_tensor_meta(output_tensor, fs_checker)
|
fs_checker = send_tensor_meta(output_tensor, fs_checker)
|
||||||
|
@ -257,9 +251,11 @@ class PipelineSchedule(BaseSchedule):
|
||||||
for i in range(num_microbatches_remaining):
|
for i in range(num_microbatches_remaining):
|
||||||
last_iteration = (i == (num_microbatches_remaining - 1))
|
last_iteration = (i == (num_microbatches_remaining - 1))
|
||||||
|
|
||||||
output_tensor = self.forward_step(input_tensor,
|
output_tensor = self.forward_step(
|
||||||
return_tensors,
|
model, criterion,
|
||||||
return_loss=return_loss)
|
input_tensor, return_tensors,
|
||||||
|
grad_accum_size, return_loss=return_loss
|
||||||
|
)
|
||||||
if forward_only:
|
if forward_only:
|
||||||
send_forward(output_tensor)
|
send_forward(output_tensor)
|
||||||
|
|
||||||
|
@ -279,9 +275,11 @@ class PipelineSchedule(BaseSchedule):
|
||||||
input_tensor = input_tensors.pop(0)
|
input_tensor = input_tensors.pop(0)
|
||||||
output_tensor = output_tensors.pop(0)
|
output_tensor = output_tensors.pop(0)
|
||||||
|
|
||||||
input_tensor_grad = self.backward_step(input_tensor,
|
input_tensor_grad = self.backward_step(
|
||||||
output_tensor,
|
optimizer,
|
||||||
output_tensor_grad)
|
input_tensor, output_tensor,
|
||||||
|
output_tensor_grad
|
||||||
|
)
|
||||||
|
|
||||||
if last_iteration:
|
if last_iteration:
|
||||||
input_tensor = None
|
input_tensor = None
|
||||||
|
@ -298,9 +296,11 @@ class PipelineSchedule(BaseSchedule):
|
||||||
|
|
||||||
output_tensor_grad = recv_backward(bt_shape)
|
output_tensor_grad = recv_backward(bt_shape)
|
||||||
|
|
||||||
input_tensor_grad = self.backward_step(input_tensor,
|
input_tensor_grad = self.backward_step(
|
||||||
output_tensor,
|
optimizer,
|
||||||
output_tensor_grad)
|
input_tensor, output_tensor,
|
||||||
|
output_tensor_grad
|
||||||
|
)
|
||||||
|
|
||||||
send_backward(input_tensor_grad)
|
send_backward(input_tensor_grad)
|
||||||
|
|
||||||
|
@ -309,8 +309,11 @@ class PipelineSchedule(BaseSchedule):
|
||||||
output, label, loss = tuple(map(list, zip(*return_tensors)))
|
output, label, loss = tuple(map(list, zip(*return_tensors)))
|
||||||
return (torch.cat(output, dim=0),
|
return (torch.cat(output, dim=0),
|
||||||
torch.cat(label, dim=0),
|
torch.cat(label, dim=0),
|
||||||
sum(loss))
|
sum(loss) * grad_accum_size)
|
||||||
else:
|
else:
|
||||||
return tuple((torch.cat(return_tensors, dim=0), None, None))
|
return tuple((torch.cat(return_tensors, dim=0), None, None))
|
||||||
else:
|
else:
|
||||||
return tuple((None, None, None))
|
return tuple((None, None, None))
|
||||||
|
|
||||||
|
def optimizer_step(self, model, optimizer, grad_clipping: float = 0.0):
|
||||||
|
optimizer.step()
|
||||||
|
|
|
@ -14,3 +14,14 @@ def convert_to_fp16(data: Union[Tensor, List[Tensor]]):
|
||||||
else:
|
else:
|
||||||
raise TypeError(f"Expected argument 'data' to be a Tensor or a list/tuple of Tensor, but got {type(data)}")
|
raise TypeError(f"Expected argument 'data' to be a Tensor or a list/tuple of Tensor, but got {type(data)}")
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_fp32(data: Union[Tensor, List[Tensor]]):
|
||||||
|
if isinstance(data, Tensor):
|
||||||
|
ret = data.float()
|
||||||
|
elif isinstance(data, (list, tuple)):
|
||||||
|
ret = [val.float() for val in data]
|
||||||
|
else:
|
||||||
|
raise TypeError(f"Expected argument 'data' to be a Tensor or a list/tuple of Tensor, but got {type(data)}")
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
|
@ -6,18 +6,20 @@ import pprint
|
||||||
import random
|
import random
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Iterable, Optional, Union
|
from typing import Callable, Iterable, Optional, Union
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from colossalai.engine import AMP_TYPE, NoPipelineSchedule, PipelineSchedule
|
from colossalai.engine import AMP_TYPE, NoPipelineSchedule, PipelineSchedule
|
||||||
|
from colossalai.engine import Engine
|
||||||
from colossalai.logging import get_global_dist_logger, init_global_dist_logger
|
from colossalai.logging import get_global_dist_logger, init_global_dist_logger
|
||||||
from colossalai.nn import DataParallelSampler
|
from colossalai.nn import DataParallelSampler
|
||||||
from colossalai.nn.model.base_model import BaseModel
|
from colossalai.nn.model.base_model import BaseModel
|
||||||
from .builder import (ModelInitializer, build_dataset, build_loss,
|
from .builder import (ModelInitializer, build_dataset, build_loss,
|
||||||
build_lr_scheduler, build_model, build_optimizer,
|
build_model, build_optimizer,
|
||||||
build_optimizer_wrapper)
|
build_optimizer_wrapper, build_schedule)
|
||||||
from .context import Config, ParallelMode
|
from .context import Config, ParallelMode
|
||||||
from .core import global_context as gpc
|
from .core import global_context as gpc
|
||||||
from .utils import get_current_device, sync_model_param_in_dp
|
from .utils import get_current_device, sync_model_param_in_dp
|
||||||
|
@ -182,7 +184,7 @@ def initialize(config: Union[str, dict] = None,
|
||||||
backend: str = None,
|
backend: str = None,
|
||||||
train_dataloader: Optional[Union[Iterable, Callable]] = None,
|
train_dataloader: Optional[Union[Iterable, Callable]] = None,
|
||||||
test_dataloader: Optional[Union[Iterable, Callable]] = None,
|
test_dataloader: Optional[Union[Iterable, Callable]] = None,
|
||||||
):
|
) -> Tuple[Engine, DataLoader, DataLoader]:
|
||||||
'''Core function that initializes distributed environment, logger, cudnn, data, model, loss function, optimizer, and lr_scheduler(their configs are in gpc.config).
|
'''Core function that initializes distributed environment, logger, cudnn, data, model, loss function, optimizer, and lr_scheduler(their configs are in gpc.config).
|
||||||
|
|
||||||
:param config: config file or config file path are both acceptable
|
:param config: config file or config file path are both acceptable
|
||||||
|
@ -201,7 +203,7 @@ def initialize(config: Union[str, dict] = None,
|
||||||
:type train_dataloader: Optional[Union[Iterable, Callable]], optional
|
:type train_dataloader: Optional[Union[Iterable, Callable]], optional
|
||||||
:param test_dataloader: If None, the config is used to build a dataloder; Else, it should be a dataloader object or a function with no arguments which can build a dataloader, defaults to None
|
:param test_dataloader: If None, the config is used to build a dataloder; Else, it should be a dataloader object or a function with no arguments which can build a dataloader, defaults to None
|
||||||
:type test_dataloader: Optional[Union[Iterable, Callable]], optional
|
:type test_dataloader: Optional[Union[Iterable, Callable]], optional
|
||||||
:return: (model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler)
|
:return: (engine, train_dataloader, test_dataloader, criterion)
|
||||||
:rtype: tuple
|
:rtype: tuple
|
||||||
'''
|
'''
|
||||||
# initialize distributed environment
|
# initialize distributed environment
|
||||||
|
@ -337,21 +339,7 @@ def initialize(config: Union[str, dict] = None,
|
||||||
optimizer = build_optimizer_wrapper(fp16_cfg, optimizer)
|
optimizer = build_optimizer_wrapper(fp16_cfg, optimizer)
|
||||||
logger.info('Optimizer is created', ranks=[0])
|
logger.info('Optimizer is created', ranks=[0])
|
||||||
|
|
||||||
lr_scheduler = None
|
# build schedule and engine
|
||||||
if hasattr(gpc.config, 'lr_scheduler'):
|
|
||||||
if hasattr(gpc.config, 'num_steps'):
|
|
||||||
total_steps = gpc.config.num_steps
|
|
||||||
elif hasattr(gpc.config, 'num_epochs'):
|
|
||||||
total_steps = int(gpc.config.num_epochs * len(train_dataloader))
|
|
||||||
else:
|
|
||||||
raise Exception(
|
|
||||||
'Please specify training stopping criterion num_steps or num_epochs in your configuration.'
|
|
||||||
)
|
|
||||||
lr_scheduler = build_lr_scheduler(gpc.config.lr_scheduler, optimizer,
|
|
||||||
total_steps, len(train_dataloader))
|
|
||||||
logger.info('Learning rate scheduler is created', ranks=[0])
|
|
||||||
|
|
||||||
# pipeline or no pipeline schedule
|
|
||||||
if hasattr(gpc.config, 'fp16'):
|
if hasattr(gpc.config, 'fp16'):
|
||||||
amp_type = gpc.config.fp16.mode
|
amp_type = gpc.config.fp16.mode
|
||||||
amp_cfg = gpc.config.fp16.copy()
|
amp_cfg = gpc.config.fp16.copy()
|
||||||
|
@ -360,12 +348,32 @@ def initialize(config: Union[str, dict] = None,
|
||||||
amp_type = None
|
amp_type = None
|
||||||
amp_cfg = None
|
amp_cfg = None
|
||||||
|
|
||||||
if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
|
engine_cfg = gpc.config.get('engine', dict())
|
||||||
assert hasattr(gpc.config,
|
schedule_cfg = engine_cfg.pop('schedule', None)
|
||||||
'schedule'), "Config 'schedule' not found in your configuration file for pipeline parallel training"
|
|
||||||
|
schedule_type = None
|
||||||
|
if schedule_cfg is not None:
|
||||||
|
schedule_type = schedule_cfg.get('type', None)
|
||||||
|
|
||||||
|
if schedule_type is not None:
|
||||||
|
# run customized schedule
|
||||||
|
schedule_cfg['amp_type'] = amp_type
|
||||||
|
schedule_cfg['amp_config'] = amp_cfg
|
||||||
|
schedule = build_schedule(schedule_cfg)
|
||||||
|
elif gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
|
||||||
|
assert schedule_cfg is not None, \
|
||||||
|
"Config 'engine.schedule' not found in your configuration file for pipeline parallel training"
|
||||||
schedule = PipelineSchedule(
|
schedule = PipelineSchedule(
|
||||||
amp_type=amp_type, amp_config=amp_cfg, **gpc.config.schedule.copy())
|
amp_type=amp_type, amp_config=amp_cfg, **schedule_cfg.copy())
|
||||||
else:
|
else:
|
||||||
schedule = NoPipelineSchedule(amp_type=amp_type, amp_config=amp_cfg)
|
schedule = NoPipelineSchedule(amp_type=amp_type, amp_config=amp_cfg)
|
||||||
|
|
||||||
return model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler
|
engine = Engine(
|
||||||
|
model=model,
|
||||||
|
optimizer=optimizer,
|
||||||
|
criterion=criterion,
|
||||||
|
step_schedule=schedule,
|
||||||
|
**gpc.config.get('engine', dict())
|
||||||
|
)
|
||||||
|
|
||||||
|
return engine, train_dataloader, test_dataloader
|
||||||
|
|
|
@ -7,6 +7,7 @@ from torch import Tensor
|
||||||
from colossalai.context.parallel_mode import ParallelMode
|
from colossalai.context.parallel_mode import ParallelMode
|
||||||
from colossalai.core import global_context as gpc
|
from colossalai.core import global_context as gpc
|
||||||
from colossalai.utils import get_current_device
|
from colossalai.utils import get_current_device
|
||||||
|
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||||
|
|
||||||
|
|
||||||
def matmul_2d(a,
|
def matmul_2d(a,
|
||||||
|
@ -60,6 +61,7 @@ class Matmul_AB_2D(torch.autograd.Function):
|
||||||
"""Matrix multiplication for :math:`C = AB`
|
"""Matrix multiplication for :math:`C = AB`
|
||||||
"""
|
"""
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@custom_fwd(cast_inputs=torch.float16)
|
||||||
def forward(ctx: Any,
|
def forward(ctx: Any,
|
||||||
A: Tensor,
|
A: Tensor,
|
||||||
B: Tensor,
|
B: Tensor,
|
||||||
|
@ -120,32 +122,32 @@ class Matmul_AB_2D(torch.autograd.Function):
|
||||||
return out
|
return out
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@custom_bwd
|
||||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||||
A, B = ctx.saved_tensors
|
A, B = ctx.saved_tensors
|
||||||
A_grad = Matmul_ABT_2D.forward(
|
with torch.no_grad():
|
||||||
None,
|
A_grad = Matmul_ABT_2D.apply(
|
||||||
output_grad, B,
|
output_grad, B,
|
||||||
ctx.summa_dim, ctx.A_shape,
|
ctx.summa_dim, ctx.A_shape,
|
||||||
ctx.row_rank, ctx.col_rank,
|
ctx.row_rank, ctx.col_rank,
|
||||||
ctx.row_parallel_mode,
|
ctx.row_parallel_mode,
|
||||||
ctx.col_parallel_mode,
|
ctx.col_parallel_mode,
|
||||||
ctx.data_parallel_rank,
|
ctx.data_parallel_rank,
|
||||||
ctx.pipeline_parallel_rank,
|
ctx.pipeline_parallel_rank,
|
||||||
ctx.pipeline_parallel_size,
|
ctx.pipeline_parallel_size,
|
||||||
ctx.tensor_parallel_size
|
ctx.tensor_parallel_size
|
||||||
)
|
)
|
||||||
B_grad = Matmul_ATB_2D.forward(
|
B_grad = Matmul_ATB_2D.apply(
|
||||||
None,
|
A, output_grad,
|
||||||
A, output_grad,
|
ctx.summa_dim, ctx.B_shape,
|
||||||
ctx.summa_dim, ctx.B_shape,
|
ctx.row_rank, ctx.col_rank,
|
||||||
ctx.row_rank, ctx.col_rank,
|
ctx.row_parallel_mode,
|
||||||
ctx.row_parallel_mode,
|
ctx.col_parallel_mode,
|
||||||
ctx.col_parallel_mode,
|
ctx.data_parallel_rank,
|
||||||
ctx.data_parallel_rank,
|
ctx.pipeline_parallel_rank,
|
||||||
ctx.pipeline_parallel_rank,
|
ctx.pipeline_parallel_size,
|
||||||
ctx.pipeline_parallel_size,
|
ctx.tensor_parallel_size
|
||||||
ctx.tensor_parallel_size
|
)
|
||||||
)
|
|
||||||
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None
|
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None
|
||||||
|
|
||||||
|
|
||||||
|
@ -153,6 +155,7 @@ class Matmul_ABT_2D(torch.autograd.Function):
|
||||||
"""Matrix multiplication for :math:`C = AB^T`
|
"""Matrix multiplication for :math:`C = AB^T`
|
||||||
"""
|
"""
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@custom_fwd(cast_inputs=torch.float16)
|
||||||
def forward(ctx: Any,
|
def forward(ctx: Any,
|
||||||
A: Tensor,
|
A: Tensor,
|
||||||
B: Tensor,
|
B: Tensor,
|
||||||
|
@ -214,32 +217,33 @@ class Matmul_ABT_2D(torch.autograd.Function):
|
||||||
return out
|
return out
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@custom_bwd
|
||||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||||
A, B = ctx.saved_tensors
|
A, B = ctx.saved_tensors
|
||||||
A_grad = Matmul_AB_2D.forward(
|
|
||||||
None,
|
with torch.no_grad():
|
||||||
output_grad, B,
|
A_grad = Matmul_AB_2D.apply(
|
||||||
ctx.summa_dim, ctx.A_shape,
|
output_grad, B,
|
||||||
ctx.row_rank, ctx.col_rank,
|
ctx.summa_dim, ctx.A_shape,
|
||||||
ctx.row_parallel_mode,
|
ctx.row_rank, ctx.col_rank,
|
||||||
ctx.col_parallel_mode,
|
ctx.row_parallel_mode,
|
||||||
ctx.data_parallel_rank,
|
ctx.col_parallel_mode,
|
||||||
ctx.pipeline_parallel_rank,
|
ctx.data_parallel_rank,
|
||||||
ctx.pipeline_parallel_size,
|
ctx.pipeline_parallel_rank,
|
||||||
ctx.tensor_parallel_size
|
ctx.pipeline_parallel_size,
|
||||||
)
|
ctx.tensor_parallel_size
|
||||||
B_grad = Matmul_ATB_2D.forward(
|
)
|
||||||
None,
|
B_grad = Matmul_ATB_2D.apply(
|
||||||
output_grad, A,
|
output_grad, A,
|
||||||
ctx.summa_dim, ctx.B_shape,
|
ctx.summa_dim, ctx.B_shape,
|
||||||
ctx.row_rank, ctx.col_rank,
|
ctx.row_rank, ctx.col_rank,
|
||||||
ctx.row_parallel_mode,
|
ctx.row_parallel_mode,
|
||||||
ctx.col_parallel_mode,
|
ctx.col_parallel_mode,
|
||||||
ctx.data_parallel_rank,
|
ctx.data_parallel_rank,
|
||||||
ctx.pipeline_parallel_rank,
|
ctx.pipeline_parallel_rank,
|
||||||
ctx.pipeline_parallel_size,
|
ctx.pipeline_parallel_size,
|
||||||
ctx.tensor_parallel_size
|
ctx.tensor_parallel_size
|
||||||
)
|
)
|
||||||
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None
|
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None
|
||||||
|
|
||||||
|
|
||||||
|
@ -247,6 +251,7 @@ class Matmul_ATB_2D(torch.autograd.Function):
|
||||||
"""Matrix multiplication for :math:`C = A^TB`
|
"""Matrix multiplication for :math:`C = A^TB`
|
||||||
"""
|
"""
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@custom_fwd(cast_inputs=torch.float16)
|
||||||
def forward(ctx: Any,
|
def forward(ctx: Any,
|
||||||
A: Tensor,
|
A: Tensor,
|
||||||
B: Tensor,
|
B: Tensor,
|
||||||
|
@ -308,32 +313,33 @@ class Matmul_ATB_2D(torch.autograd.Function):
|
||||||
return out
|
return out
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@custom_bwd
|
||||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||||
A, B = ctx.saved_tensors
|
A, B = ctx.saved_tensors
|
||||||
A_grad = Matmul_ABT_2D.forward(
|
|
||||||
None,
|
with torch.no_grad():
|
||||||
B, output_grad,
|
A_grad = Matmul_ABT_2D.apply(
|
||||||
ctx.summa_dim, ctx.A_shape,
|
B, output_grad,
|
||||||
ctx.row_rank, ctx.col_rank,
|
ctx.summa_dim, ctx.A_shape,
|
||||||
ctx.row_parallel_mode,
|
ctx.row_rank, ctx.col_rank,
|
||||||
ctx.col_parallel_mode,
|
ctx.row_parallel_mode,
|
||||||
ctx.data_parallel_rank,
|
ctx.col_parallel_mode,
|
||||||
ctx.pipeline_parallel_rank,
|
ctx.data_parallel_rank,
|
||||||
ctx.pipeline_parallel_size,
|
ctx.pipeline_parallel_rank,
|
||||||
ctx.tensor_parallel_size
|
ctx.pipeline_parallel_size,
|
||||||
)
|
ctx.tensor_parallel_size
|
||||||
B_grad = Matmul_AB_2D.forward(
|
)
|
||||||
None,
|
B_grad = Matmul_AB_2D.apply(
|
||||||
A, output_grad,
|
A, output_grad,
|
||||||
ctx.summa_dim, ctx.B_shape,
|
ctx.summa_dim, ctx.B_shape,
|
||||||
ctx.row_rank, ctx.col_rank,
|
ctx.row_rank, ctx.col_rank,
|
||||||
ctx.row_parallel_mode,
|
ctx.row_parallel_mode,
|
||||||
ctx.col_parallel_mode,
|
ctx.col_parallel_mode,
|
||||||
ctx.data_parallel_rank,
|
ctx.data_parallel_rank,
|
||||||
ctx.pipeline_parallel_rank,
|
ctx.pipeline_parallel_rank,
|
||||||
ctx.pipeline_parallel_size,
|
ctx.pipeline_parallel_size,
|
||||||
ctx.tensor_parallel_size
|
ctx.tensor_parallel_size
|
||||||
)
|
)
|
||||||
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None
|
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None
|
||||||
|
|
||||||
|
|
||||||
|
@ -341,6 +347,7 @@ class Add_Bias_2D(torch.autograd.Function):
|
||||||
"""Matrix add bias: :math:`C = A + b`
|
"""Matrix add bias: :math:`C = A + b`
|
||||||
"""
|
"""
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@custom_fwd(cast_inputs=torch.float16)
|
||||||
def forward(ctx: Any,
|
def forward(ctx: Any,
|
||||||
input: Tensor,
|
input: Tensor,
|
||||||
bias: Tensor,
|
bias: Tensor,
|
||||||
|
@ -384,6 +391,7 @@ class Add_Bias_2D(torch.autograd.Function):
|
||||||
return output
|
return output
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@custom_bwd
|
||||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||||
row_rank = ctx.row_rank
|
row_rank = ctx.row_rank
|
||||||
col_rank = ctx.col_rank
|
col_rank = ctx.col_rank
|
||||||
|
@ -423,6 +431,7 @@ class Add_Bias_2D(torch.autograd.Function):
|
||||||
class _LayerNorm_2D(torch.autograd.Function):
|
class _LayerNorm_2D(torch.autograd.Function):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@custom_fwd(cast_inputs=torch.float32)
|
||||||
def forward(ctx: Any,
|
def forward(ctx: Any,
|
||||||
input: Tensor,
|
input: Tensor,
|
||||||
E_x: Tensor,
|
E_x: Tensor,
|
||||||
|
@ -440,6 +449,7 @@ class _LayerNorm_2D(torch.autograd.Function):
|
||||||
return output
|
return output
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@custom_bwd
|
||||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||||
row_parallel_mode = ctx.row_parallel_mode
|
row_parallel_mode = ctx.row_parallel_mode
|
||||||
col_parallel_mode = ctx.col_parallel_mode
|
col_parallel_mode = ctx.col_parallel_mode
|
||||||
|
@ -492,6 +502,7 @@ class _LayerNorm_2D(torch.autograd.Function):
|
||||||
class _ViT_Split_Input_2D(torch.autograd.Function):
|
class _ViT_Split_Input_2D(torch.autograd.Function):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@custom_fwd(cast_inputs=torch.float16)
|
||||||
def forward(ctx: Any,
|
def forward(ctx: Any,
|
||||||
inputs: Tensor,
|
inputs: Tensor,
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
|
@ -509,6 +520,7 @@ class _ViT_Split_Input_2D(torch.autograd.Function):
|
||||||
return output
|
return output
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@custom_bwd
|
||||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||||
# output_grad: [b/q, s, h/q]
|
# output_grad: [b/q, s, h/q]
|
||||||
# grads: [b, s, h/q]
|
# grads: [b, s, h/q]
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from .cosine import CosineAnnealingLR, CosineAnnealingWarmupLR, FlatAnnealingLR, FlatAnnealingWarmupLR
|
from .cosine import CosineAnnealingLR, CosineAnnealingWarmupLR, FlatAnnealingLR, FlatAnnealingWarmupLR
|
||||||
from .linear import LinearWarmupLR, LinearWarmupDecay
|
from .linear import LinearWarmupLR
|
||||||
from .multistep import MultiStepLR, MultiStepWarmupLR
|
from .multistep import MultiStepLR, MultiStepWarmupLR
|
||||||
from .onecycle import OneCycleLR
|
from .onecycle import OneCycleLR
|
||||||
from .poly import PolynomialLR, PolynomialWarmupLR
|
from .poly import PolynomialLR, PolynomialWarmupLR
|
||||||
|
|
|
@ -66,11 +66,10 @@ class CosineAnnealingWarmupLR(WarmupScheduler):
|
||||||
:type last_epoch: int, optional
|
:type last_epoch: int, optional
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, eta_min: int = 0, last_epoch: int = -1,
|
def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, eta_min: int = 0, last_epoch: int = -1):
|
||||||
**kwargs):
|
|
||||||
base_scheduler = _CosineAnnealingLR(
|
base_scheduler = _CosineAnnealingLR(
|
||||||
optimizer, total_steps - warmup_steps, eta_min=eta_min)
|
optimizer, total_steps - warmup_steps, eta_min=eta_min, last_epoch=last_epoch)
|
||||||
super().__init__(optimizer, warmup_steps, base_scheduler, last_epoch=last_epoch)
|
super().__init__(optimizer, warmup_steps, base_scheduler)
|
||||||
|
|
||||||
|
|
||||||
@LR_SCHEDULERS.register_module
|
@LR_SCHEDULERS.register_module
|
||||||
|
|
|
@ -55,7 +55,7 @@ class DelayerScheduler(_LRScheduler):
|
||||||
|
|
||||||
|
|
||||||
class WarmupScheduler(_LRScheduler):
|
class WarmupScheduler(_LRScheduler):
|
||||||
""" Starts with a linear warmup lr schedule until it reaches N epochs the applies a scheduler
|
""" Starts with a linear warmup lr schedule until it reaches N epochs the applies a scheduler
|
||||||
|
|
||||||
:param optimizer: Wrapped optimizer.
|
:param optimizer: Wrapped optimizer.
|
||||||
:type optimizer: torch.optim.Optimizer
|
:type optimizer: torch.optim.Optimizer
|
||||||
|
@ -66,11 +66,8 @@ class WarmupScheduler(_LRScheduler):
|
||||||
:param last_epoch: The index of last epoch, defaults to -1
|
:param last_epoch: The index of last epoch, defaults to -1
|
||||||
:type last_epoch: int, optional
|
:type last_epoch: int, optional
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, optimizer, warmup_epochs, after_scheduler, last_epoch=-1):
|
def __init__(self, optimizer, warmup_epochs, after_scheduler, last_epoch=-1):
|
||||||
if warmup_epochs < 0:
|
self.warmup_epochs = int(warmup_epochs)
|
||||||
raise ValueError(f'warmup_epochs must >= 0, got {warmup_epochs}')
|
|
||||||
self.warmup_epochs = warmup_epochs
|
|
||||||
self.after_scheduler = after_scheduler
|
self.after_scheduler = after_scheduler
|
||||||
self.finished = False
|
self.finished = False
|
||||||
super().__init__(optimizer, last_epoch)
|
super().__init__(optimizer, last_epoch)
|
||||||
|
@ -79,14 +76,10 @@ class WarmupScheduler(_LRScheduler):
|
||||||
if self.last_epoch >= self.warmup_epochs:
|
if self.last_epoch >= self.warmup_epochs:
|
||||||
if not self.finished:
|
if not self.finished:
|
||||||
self.after_scheduler.base_lrs = self.base_lrs
|
self.after_scheduler.base_lrs = self.base_lrs
|
||||||
# reset lr to base_lr
|
|
||||||
for group, base_lr in zip(self.optimizer.param_groups, self.base_lrs):
|
|
||||||
group['lr'] = base_lr
|
|
||||||
self.finished = True
|
self.finished = True
|
||||||
with _enable_get_lr_call(self.after_scheduler):
|
return self.after_scheduler.get_lr()
|
||||||
return self.after_scheduler.get_lr()
|
|
||||||
|
|
||||||
return [(self.last_epoch + 1) / (self.warmup_epochs + 1) * lr for lr in self.base_lrs]
|
return [(self.last_epoch + 1) / self.warmup_epochs * lr for lr in self.base_lrs]
|
||||||
|
|
||||||
def step(self, epoch=None):
|
def step(self, epoch=None):
|
||||||
if self.finished:
|
if self.finished:
|
||||||
|
|
|
@ -28,18 +28,3 @@ class LinearWarmupLR(_LRScheduler):
|
||||||
else:
|
else:
|
||||||
return [(self.total_steps - self.last_epoch) / (self.total_steps - self.warmup_steps) * lr for lr in
|
return [(self.total_steps - self.last_epoch) / (self.total_steps - self.warmup_steps) * lr for lr in
|
||||||
self.base_lrs]
|
self.base_lrs]
|
||||||
|
|
||||||
|
|
||||||
@LR_SCHEDULERS.register_module
|
|
||||||
class LinearWarmupDecay(_LRScheduler):
|
|
||||||
def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, last_epoch: int = -1, **kwargs):
|
|
||||||
self.warmup_steps = int(warmup_steps)
|
|
||||||
self.total_steps = total_steps
|
|
||||||
super().__init__(optimizer, last_epoch=last_epoch)
|
|
||||||
|
|
||||||
def get_lr(self):
|
|
||||||
if self.last_epoch < self.warmup_steps:
|
|
||||||
return [(self.last_epoch + 1) / self.warmup_steps * lr for lr in self.base_lrs]
|
|
||||||
else:
|
|
||||||
return [(self.total_steps - self.last_epoch - 1) / (self.total_steps - self.warmup_steps) * lr for lr in
|
|
||||||
self.base_lrs]
|
|
||||||
|
|
|
@ -27,12 +27,7 @@ class MultiStepLR(_MultiStepLR):
|
||||||
:type last_epoch: int, optional
|
:type last_epoch: int, optional
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, optimizer, total_steps: int, milestones: List[int] = None, gamma: float = 0.1,
|
def __init__(self, optimizer, total_steps: int, milestones: List[int] = None, gamma: float = 0.1, last_epoch: int = -1, **kwargs):
|
||||||
num_steps_per_epoch: int = -1, last_epoch: int = -1, **kwargs):
|
|
||||||
if num_steps_per_epoch <= 0:
|
|
||||||
raise ValueError(
|
|
||||||
f'num_steps_per_epoch must > 0, got {num_steps_per_epoch}')
|
|
||||||
milestones = [v * num_steps_per_epoch for v in milestones]
|
|
||||||
super().__init__(optimizer, milestones, gamma=gamma, last_epoch=last_epoch)
|
super().__init__(optimizer, milestones, gamma=gamma, last_epoch=last_epoch)
|
||||||
|
|
||||||
|
|
||||||
|
@ -57,14 +52,11 @@ class MultiStepWarmupLR(WarmupScheduler):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, milestones: List[int] = None,
|
def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, milestones: List[int] = None,
|
||||||
gamma: float = 0.1, num_steps_per_epoch: int = -1, last_epoch: int = -1, **kwargs):
|
gamma: float = 0.1, last_epoch: int = -1, **kwargs):
|
||||||
if len(milestones) == 0:
|
if len(milestones) == 0:
|
||||||
raise ValueError('milestones cannot be empty')
|
raise ValueError('milestones cannot be empty')
|
||||||
if num_steps_per_epoch <= 0:
|
milestones = [
|
||||||
raise ValueError(
|
v - warmup_steps for v in milestones if v >= warmup_steps]
|
||||||
f'num_steps_per_epoch must > 0, got {num_steps_per_epoch}')
|
|
||||||
milestones = [v * num_steps_per_epoch - warmup_steps for v in milestones if v *
|
|
||||||
num_steps_per_epoch >= warmup_steps]
|
|
||||||
base_scheduler = _MultiStepLR(optimizer, milestones=milestones,
|
base_scheduler = _MultiStepLR(optimizer, milestones=milestones,
|
||||||
gamma=gamma)
|
gamma=gamma)
|
||||||
super().__init__(optimizer, warmup_steps, base_scheduler, last_epoch=last_epoch)
|
super().__init__(optimizer, warmup_steps, base_scheduler, last_epoch=last_epoch)
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from torch.optim.lr_scheduler import LambdaLR as _LambdaLR
|
from torch.optim.lr_scheduler import LambdaLR as _LambdaLR
|
||||||
from torch.optim.lr_scheduler import MultiplicativeLR as _MultiplicativeLR
|
from torch.optim.lr_scheduler import MultiplicativeLR as _MultiplicativeLR
|
||||||
from torch.optim.lr_scheduler import StepLR as _StepLR
|
from torch.optim.lr_scheduler import StepLR as _StepLR
|
||||||
from torch.optim.lr_scheduler import _LRScheduler
|
from torch.optim.lr_scheduler import ExponentialLR as _ExponentialLR
|
||||||
|
|
||||||
from colossalai.registry import LR_SCHEDULERS
|
from colossalai.registry import LR_SCHEDULERS
|
||||||
|
|
||||||
|
@ -25,11 +25,8 @@ class LambdaLR(_LambdaLR):
|
||||||
:type last_epoch: int, optional
|
:type last_epoch: int, optional
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, optimizer, total_steps, lr_lambda=None, num_steps_per_epoch: int = -1,
|
def __init__(self, optimizer, total_steps, lr_lambda=None, last_epoch: int = -1) -> None:
|
||||||
last_epoch: int = -1) -> None:
|
super().__init__(optimizer, lr_lambda, last_epoch=last_epoch)
|
||||||
def func(step): return lr_lambda(step // num_steps_per_epoch)
|
|
||||||
|
|
||||||
super().__init__(optimizer, func, last_epoch=last_epoch)
|
|
||||||
|
|
||||||
|
|
||||||
@LR_SCHEDULERS.register_module
|
@LR_SCHEDULERS.register_module
|
||||||
|
@ -51,11 +48,8 @@ class MultiplicativeLR(_MultiplicativeLR):
|
||||||
:type last_epoch: int, optional
|
:type last_epoch: int, optional
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, optimizer, total_steps, lr_lambda=None, num_steps_per_epoch: int = -1,
|
def __init__(self, optimizer, total_steps, lr_lambda=None, last_epoch: int = -1) -> None:
|
||||||
last_epoch: int = -1) -> None:
|
super().__init__(optimizer, lr_lambda, last_epoch=last_epoch)
|
||||||
def func(step): return lr_lambda(step // num_steps_per_epoch)
|
|
||||||
|
|
||||||
super().__init__(optimizer, func, last_epoch=last_epoch)
|
|
||||||
|
|
||||||
|
|
||||||
@LR_SCHEDULERS.register_module
|
@LR_SCHEDULERS.register_module
|
||||||
|
@ -79,14 +73,13 @@ class StepLR(_StepLR):
|
||||||
:type last_epoch: int, optional
|
:type last_epoch: int, optional
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, optimizer, total_steps, step_size: int = 1, gamma: float = 0.1, num_steps_per_epoch: int = -1,
|
def __init__(self, optimizer, total_steps, step_size: int = 1, gamma: float = 0.1, last_epoch: int = -1) -> None:
|
||||||
last_epoch: int = -1) -> None:
|
super().__init__(optimizer, step_size,
|
||||||
super().__init__(optimizer, step_size * num_steps_per_epoch,
|
|
||||||
gamma=gamma, last_epoch=last_epoch)
|
gamma=gamma, last_epoch=last_epoch)
|
||||||
|
|
||||||
|
|
||||||
@LR_SCHEDULERS.register_module
|
@LR_SCHEDULERS.register_module
|
||||||
class ExponentialLR(_LRScheduler):
|
class ExponentialLR(_ExponentialLR):
|
||||||
"""Decays the learning rate of each parameter group by gamma every epoch.
|
"""Decays the learning rate of each parameter group by gamma every epoch.
|
||||||
When last_epoch=-1, sets initial lr as lr
|
When last_epoch=-1, sets initial lr as lr
|
||||||
|
|
||||||
|
@ -102,21 +95,6 @@ class ExponentialLR(_LRScheduler):
|
||||||
:type last_epoch: int, optional
|
:type last_epoch: int, optional
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, optimizer, total_steps, gamma: float = 1.0, num_steps_per_epoch: int = -1,
|
def __init__(self, optimizer, total_steps, gamma: float = 1.0,
|
||||||
last_epoch: int = -1) -> None:
|
last_epoch: int = -1) -> None:
|
||||||
self.gamma = gamma
|
super().__init__(optimizer, gamma, last_epoch=last_epoch)
|
||||||
self.num_steps_per_epoch = num_steps_per_epoch
|
|
||||||
super().__init__(optimizer, last_epoch=last_epoch)
|
|
||||||
|
|
||||||
def get_lr(self):
|
|
||||||
if self.last_epoch == 0:
|
|
||||||
return self.base_lrs
|
|
||||||
elif (self.last_epoch + 1) % self.num_steps_per_epoch == 0:
|
|
||||||
return [group['lr'] * self.gamma
|
|
||||||
for group in self.optimizer.param_groups]
|
|
||||||
return [group['lr']
|
|
||||||
for group in self.optimizer.param_groups]
|
|
||||||
|
|
||||||
def _get_closed_form_lr(self):
|
|
||||||
return [base_lr * self.gamma ** (self.last_epoch // self.num_steps_per_epoch)
|
|
||||||
for base_lr in self.base_lrs]
|
|
||||||
|
|
|
@ -106,7 +106,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
|
||||||
tensor_parallel_norm = _calc_lp(tensor_parallel_grads, norm_type)
|
tensor_parallel_norm = _calc_lp(tensor_parallel_grads, norm_type)
|
||||||
no_tensor_parallel_grads = _calc_lp(
|
no_tensor_parallel_grads = _calc_lp(
|
||||||
no_tensor_parallel_grads, norm_type)
|
no_tensor_parallel_grads, norm_type)
|
||||||
if gpc.is_initialized(ParallelMode.TENSOR):
|
if gpc.is_initialized(ParallelMode.TENSOR) and len(tensor_parallel_grads) > 0:
|
||||||
# Sum across all model-parallel GPUs.
|
# Sum across all model-parallel GPUs.
|
||||||
torch.distributed.all_reduce(tensor_parallel_norm,
|
torch.distributed.all_reduce(tensor_parallel_norm,
|
||||||
op=torch.distributed.ReduceOp.SUM,
|
op=torch.distributed.ReduceOp.SUM,
|
||||||
|
|
|
@ -6,6 +6,7 @@ import math
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from deepspeed.git_version_info import version
|
from deepspeed.git_version_info import version
|
||||||
from deepspeed.moe.utils import is_moe_param
|
from deepspeed.moe.utils import is_moe_param
|
||||||
|
@ -13,7 +14,7 @@ try:
|
||||||
from deepspeed.ops.op_builder import UtilsBuilder
|
from deepspeed.ops.op_builder import UtilsBuilder
|
||||||
from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION_GRADIENTS
|
from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION_GRADIENTS
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print('DeepSpeed is required if you want to use ZeRO.')
|
pass
|
||||||
from packaging import version as pkg_version
|
from packaging import version as pkg_version
|
||||||
from torch._six import inf
|
from torch._six import inf
|
||||||
from torch.distributed.distributed_c10d import _get_global_rank
|
from torch.distributed.distributed_c10d import _get_global_rank
|
||||||
|
@ -251,7 +252,7 @@ class ZeroRedundancyOptimizer_Level_2(Optimizer):
|
||||||
self.nccl_start_alignment_factor = 2
|
self.nccl_start_alignment_factor = 2
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
allgather_bucket_size % self.nccl_start_alignment_factor == 0), f"allgather_bucket_size must be a multiple of nccl_start_alignment_factor, {self.nccl_start_alignment_factor} "
|
allgather_bucket_size % self.nccl_start_alignment_factor == 0), f"allgather_bucket_size must be a multiple of nccl_start_alignment_factor, {self.nccl_start_alignment_factor} "
|
||||||
|
|
||||||
self.all_reduce_print = False
|
self.all_reduce_print = False
|
||||||
self.dtype = self.optimizer.param_groups[0]['params'][0].dtype
|
self.dtype = self.optimizer.param_groups[0]['params'][0].dtype
|
||||||
|
@ -759,7 +760,7 @@ class ZeroRedundancyOptimizer_Level_2(Optimizer):
|
||||||
elif start_index > current_index and start_index < (current_index +
|
elif start_index > current_index and start_index < (current_index +
|
||||||
param_size):
|
param_size):
|
||||||
assert (
|
assert (
|
||||||
first_offset == 0), "This can happen either zero or only once as this must be the first tensor in the partition"
|
first_offset == 0), "This can happen either zero or only once as this must be the first tensor in the partition"
|
||||||
first_offset = start_index - current_index
|
first_offset = start_index - current_index
|
||||||
|
|
||||||
set_key_value_list(self.param_to_partition_ids[i],
|
set_key_value_list(self.param_to_partition_ids[i],
|
||||||
|
@ -803,7 +804,7 @@ class ZeroRedundancyOptimizer_Level_2(Optimizer):
|
||||||
def report_ipg_memory_usage(self, tag, param_elems):
|
def report_ipg_memory_usage(self, tag, param_elems):
|
||||||
elem_count = self.elements_in_ipg_bucket + param_elems
|
elem_count = self.elements_in_ipg_bucket + param_elems
|
||||||
percent_of_bucket_size = (
|
percent_of_bucket_size = (
|
||||||
100.0 * elem_count) // self.reduce_bucket_size
|
100.0 * elem_count) // self.reduce_bucket_size
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
report_memory_usage(
|
report_memory_usage(
|
||||||
f"{tag}: elems in_bucket {self.elements_in_ipg_bucket} param {param_elems} max_percent {percent_of_bucket_size}"
|
f"{tag}: elems in_bucket {self.elements_in_ipg_bucket} param {param_elems} max_percent {percent_of_bucket_size}"
|
||||||
|
@ -1491,7 +1492,7 @@ class ZeroRedundancyOptimizer_Level_2(Optimizer):
|
||||||
params_in_partition.append(tensor)
|
params_in_partition.append(tensor)
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
first_offset == 0), "This can happen either zero or only once as this must be the first tensor in the partition"
|
first_offset == 0), "This can happen either zero or only once as this must be the first tensor in the partition"
|
||||||
first_offset = start_index - current_index
|
first_offset = start_index - current_index
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
@ -1799,7 +1800,7 @@ class ZeroRedundancyOptimizer_Level_2(Optimizer):
|
||||||
num_elements = shard_size
|
num_elements = shard_size
|
||||||
|
|
||||||
assert shard_size * \
|
assert shard_size * \
|
||||||
num_shards <= partitioned_params[partition_id].numel()
|
num_shards <= partitioned_params[partition_id].numel()
|
||||||
|
|
||||||
for shard_id in range(num_shards):
|
for shard_id in range(num_shards):
|
||||||
|
|
||||||
|
@ -2248,7 +2249,7 @@ def estimate_zero2_model_states_mem_needs(total_params,
|
||||||
if cpu_offload:
|
if cpu_offload:
|
||||||
gpu_mem = 2 * total_params
|
gpu_mem = 2 * total_params
|
||||||
cpu_mem = total_params * \
|
cpu_mem = total_params * \
|
||||||
max(4 * total_gpus, 16) * additional_buffer_factor
|
max(4 * total_gpus, 16) * additional_buffer_factor
|
||||||
else:
|
else:
|
||||||
gpu_mem = 4 * total_params + int(16 * total_params / total_gpus)
|
gpu_mem = 4 * total_params + int(16 * total_params / total_gpus)
|
||||||
cpu_mem = total_params * 4 * num_gpus_per_node * additional_buffer_factor
|
cpu_mem = total_params * 4 * num_gpus_per_node * additional_buffer_factor
|
||||||
|
|
|
@ -21,7 +21,7 @@ try:
|
||||||
from deepspeed.runtime.zero.partition_parameters import *
|
from deepspeed.runtime.zero.partition_parameters import *
|
||||||
from deepspeed.runtime.zero.partition_parameters import _init_external_params
|
from deepspeed.runtime.zero.partition_parameters import _init_external_params
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print('DeepSpeed is required if you want to use ZeRO.')
|
pass
|
||||||
|
|
||||||
from torch._six import inf
|
from torch._six import inf
|
||||||
from torch.distributed.distributed_c10d import _get_global_rank
|
from torch.distributed.distributed_c10d import _get_global_rank
|
||||||
|
|
|
@ -20,3 +20,4 @@ TRANSFORMS = Registry('transforms', third_party_library=[transforms])
|
||||||
PIPE_ALLOC_POLICY = Registry('pipeline_allocation_policy')
|
PIPE_ALLOC_POLICY = Registry('pipeline_allocation_policy')
|
||||||
SAMPLERS = Registry('samplers')
|
SAMPLERS = Registry('samplers')
|
||||||
LR_SCHEDULERS = Registry('lr_schedulers')
|
LR_SCHEDULERS = Registry('lr_schedulers')
|
||||||
|
SCHEDULE = Registry('schedules')
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from ._trainer import Trainer
|
from ._trainer import Trainer
|
||||||
from .hooks import *
|
from .hooks import *
|
||||||
from .metric import Loss, Accuracy2D, Accuracy3D, Accuracy2p5D
|
from .metric import Loss, Accuracy2D, Accuracy3D, Accuracy2p5D, LearningRate
|
||||||
|
|
||||||
__all__ = ['Trainer', 'Loss', 'Accuracy3D', 'Accuracy2D', 'Accuracy2p5D']
|
__all__ = ['Trainer', 'Loss', 'Accuracy3D', 'Accuracy2D', 'Accuracy2p5D', 'LearningRate']
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
# -*- encoding: utf-8 -*-
|
# -*- encoding: utf-8 -*-
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
from typing import Union, List
|
from typing import Union, List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
@ -10,12 +9,11 @@ from torch.utils.data import DataLoader
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from colossalai.builder import build_hooks
|
from colossalai.builder import build_hooks
|
||||||
from colossalai.checkpointing import save_checkpoint, load_checkpoint, get_checkpoint_path
|
|
||||||
from colossalai.context import Config
|
|
||||||
from colossalai.engine import Engine
|
from colossalai.engine import Engine
|
||||||
from colossalai.logging import get_global_dist_logger
|
from colossalai.logging import get_global_dist_logger
|
||||||
from colossalai.utils import get_global_multitimer, is_dp_rank_0, is_tp_rank_0, is_no_pp_or_last_stage
|
|
||||||
from colossalai.nn.data import DataParallelSampler
|
from colossalai.nn.data import DataParallelSampler
|
||||||
|
from colossalai.utils import MultiTimer
|
||||||
|
from colossalai.utils import is_dp_rank_0, is_tp_rank_0, is_no_pp_or_last_stage
|
||||||
|
|
||||||
|
|
||||||
class Trainer:
|
class Trainer:
|
||||||
|
@ -30,43 +28,31 @@ class Trainer:
|
||||||
:type hoooks_cfg: Config, optional
|
:type hoooks_cfg: Config, optional
|
||||||
:type verbose: bool, optional
|
:type verbose: bool, optional
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
engine: Engine,
|
engine: Engine,
|
||||||
hooks_cfg: Optional[Config] = None,
|
verbose: bool = False,
|
||||||
verbose: bool = False):
|
timer: MultiTimer = None):
|
||||||
# training-ralated params
|
# training-ralated params
|
||||||
self._engine = engine
|
self._engine = engine
|
||||||
self._max_epochs = float('inf')
|
self._max_epochs = 0
|
||||||
self._max_steps = float('inf')
|
|
||||||
self._cur_epoch = 0
|
self._cur_epoch = 0
|
||||||
|
self._max_steps = 0
|
||||||
self._cur_step = 0
|
self._cur_step = 0
|
||||||
|
self._steps_per_epoch = 0
|
||||||
# data-related params
|
|
||||||
self._train_dataloader = None
|
|
||||||
self._test_dataloader = None
|
|
||||||
|
|
||||||
# misc params
|
# misc params
|
||||||
self._display_progress = False
|
|
||||||
self._logger = get_global_dist_logger()
|
self._logger = get_global_dist_logger()
|
||||||
self._verbose = verbose
|
self._verbose = verbose
|
||||||
|
|
||||||
# hooks can store states in this dict, and could be consumed by other hooks
|
# hooks can store states in this dict, and could be consumed by other hooks
|
||||||
self.states = {}
|
self.states = dict()
|
||||||
|
|
||||||
# build hooks
|
# build hooks
|
||||||
self.hooks = list()
|
self.hooks = list()
|
||||||
if hooks_cfg is not None:
|
|
||||||
for cfg in hooks_cfg:
|
|
||||||
hook = build_hooks(cfg, self)
|
|
||||||
self.hooks.append(hook)
|
|
||||||
self.hooks.sort(key=lambda hook: hook.priority)
|
|
||||||
if self._verbose:
|
|
||||||
for hook in self.hooks:
|
|
||||||
self._logger.info(
|
|
||||||
f'build {hook.__class__.__name__} for train, priority = {hook.priority}', ranks=[0])
|
|
||||||
|
|
||||||
# timer
|
# multi-timer for time benchmarking
|
||||||
self._timer = get_global_multitimer()
|
self._timer = timer
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def cur_epoch(self):
|
def cur_epoch(self):
|
||||||
|
@ -74,13 +60,65 @@ class Trainer:
|
||||||
"""
|
"""
|
||||||
return self._cur_epoch
|
return self._cur_epoch
|
||||||
|
|
||||||
|
@cur_epoch.setter
|
||||||
|
def cur_epoch(self, epoch: int):
|
||||||
|
"""Set how many epochs have been processed.
|
||||||
|
"""
|
||||||
|
# allow setter for training resumption
|
||||||
|
self._cur_epoch = epoch
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def cur_step(self):
|
def cur_step(self):
|
||||||
"""Returns how many iteration steps have been processed.
|
"""Returns how many iteration steps have been processed.
|
||||||
"""
|
"""
|
||||||
return self._cur_step
|
return self._cur_step
|
||||||
|
|
||||||
def call_hooks(self, func, output=None):
|
@property
|
||||||
|
def max_epochs(self):
|
||||||
|
return self._max_epochs
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max_steps(self):
|
||||||
|
return self._max_steps
|
||||||
|
|
||||||
|
@property
|
||||||
|
def steps_per_epoch(self):
|
||||||
|
return self._steps_per_epoch
|
||||||
|
|
||||||
|
@property
|
||||||
|
def engine(self):
|
||||||
|
return self._engine
|
||||||
|
|
||||||
|
@engine.setter
|
||||||
|
def engine(self, engine_: Engine):
|
||||||
|
self._engine = engine_
|
||||||
|
|
||||||
|
def _set_current_step(self, epoch: int):
|
||||||
|
"""Sets current step number.
|
||||||
|
|
||||||
|
:param epoch: Step number to be set
|
||||||
|
:type epoch: int
|
||||||
|
"""
|
||||||
|
self._cur_step = epoch * self._steps_per_epoch
|
||||||
|
|
||||||
|
def _call_timer(self, action: str, item: str, *args, **kwargs) -> None:
|
||||||
|
"""Call timer funciton with a given timer name.
|
||||||
|
|
||||||
|
:param action: Function to be called on timer
|
||||||
|
:type action: str
|
||||||
|
:param item: Name of the timer
|
||||||
|
:type item: str
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self._timer is not None:
|
||||||
|
getattr(self._timer, action)(item, *args, **kwargs)
|
||||||
|
|
||||||
|
def _reset_states(self) -> None:
|
||||||
|
"""Clear trainer states
|
||||||
|
"""
|
||||||
|
self.states = dict()
|
||||||
|
|
||||||
|
def _call_hooks(self, func, output=None):
|
||||||
"""Calls specific hooks in the current time point.
|
"""Calls specific hooks in the current time point.
|
||||||
|
|
||||||
:param func: A string represents the time point
|
:param func: A string represents the time point
|
||||||
|
@ -95,161 +133,186 @@ class Trainer:
|
||||||
else:
|
else:
|
||||||
getattr(hook, func)(*output)
|
getattr(hook, func)(*output)
|
||||||
|
|
||||||
def exceed_max_step(self):
|
@staticmethod
|
||||||
"""Checks whether the trainer exceeds the maximum number of runnning iterations.
|
def _should_display_progress(display_progress: bool):
|
||||||
|
""" Only display progress on DP rank 0, TP rank 0 and PP last rank
|
||||||
"""
|
"""
|
||||||
return self._cur_step >= self._max_steps
|
return display_progress and is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage()
|
||||||
|
|
||||||
def set_epoch(self, epoch):
|
def _train_epoch(self,
|
||||||
"""Sets current epoch number.
|
train_dataloader: DataLoader,
|
||||||
|
epoch: int = None,
|
||||||
:param epoch: Epoch number to be set
|
display_progress: bool = False):
|
||||||
:type epoch: int
|
|
||||||
"""
|
|
||||||
self._cur_epoch = epoch
|
|
||||||
|
|
||||||
def _recover_steps(self):
|
|
||||||
step = self.cur_step * self._engine.schedule.num_steps
|
|
||||||
self._cur_step = step
|
|
||||||
|
|
||||||
def _set_display_progress(self, display_progress: bool):
|
|
||||||
self._display_progress = display_progress and is_dp_rank_0(
|
|
||||||
) and is_tp_rank_0() and is_no_pp_or_last_stage()
|
|
||||||
|
|
||||||
def _train_epoch(self, epoch: int = None):
|
|
||||||
# set sampler epoch
|
# set sampler epoch
|
||||||
if epoch is not None and \
|
if epoch is not None and \
|
||||||
hasattr(self._engine.train_dataloader, 'sampler') and \
|
hasattr(train_dataloader, 'sampler') and \
|
||||||
isinstance(self._engine.train_dataloader.sampler, DataParallelSampler):
|
isinstance(train_dataloader.sampler, DataParallelSampler):
|
||||||
self._engine.train_dataloader.sampler.set_epoch(epoch)
|
train_dataloader.sampler.set_epoch(epoch)
|
||||||
|
|
||||||
|
# set training state
|
||||||
self._engine.train()
|
self._engine.train()
|
||||||
|
data_iter = iter(train_dataloader)
|
||||||
progress = range(self._engine.schedule.num_steps)
|
progress = range(self._steps_per_epoch)
|
||||||
if self._display_progress:
|
if display_progress:
|
||||||
if epoch is None:
|
if epoch is None:
|
||||||
progress = tqdm(progress, desc='[Train]')
|
progress = tqdm(progress, desc='[Train]')
|
||||||
else:
|
else:
|
||||||
progress = tqdm(progress, desc=f'[Epoch {epoch} train]')
|
progress = tqdm(progress, desc=f'[Epoch {epoch} train]')
|
||||||
|
|
||||||
# train 1 epoch
|
# train 1 epoch
|
||||||
self.call_hooks('before_train_epoch')
|
self._call_hooks('before_train_epoch')
|
||||||
self._timer.start('train-epoch')
|
self._call_timer(action='start', item='train-epoch')
|
||||||
for _ in progress:
|
for i in progress:
|
||||||
|
self._call_hooks('before_train_iter')
|
||||||
|
self._call_timer(action='start', item='train-step')
|
||||||
|
|
||||||
|
if i == self._steps_per_epoch - 1:
|
||||||
|
is_last_iteration = True
|
||||||
|
else:
|
||||||
|
is_last_iteration = False
|
||||||
|
|
||||||
|
# run 1 training step
|
||||||
|
logits, label, loss = self._engine.step(data_iter, is_last_iteration)
|
||||||
|
self._call_timer(action='stop', item='train-step', keep_in_history=True)
|
||||||
|
self._call_hooks('after_train_iter', output=(logits, label, loss))
|
||||||
|
|
||||||
self._cur_step += 1
|
self._cur_step += 1
|
||||||
|
|
||||||
self.call_hooks('before_train_iter')
|
# stop when max iter is reached
|
||||||
self._timer.start('train-step')
|
if self._exceed_max_step():
|
||||||
logits, label, loss = self._engine.step()
|
|
||||||
self._timer.stop('train-step', keep_in_history=True)
|
|
||||||
self.call_hooks('after_train_iter', output=(logits, label, loss))
|
|
||||||
|
|
||||||
if self.exceed_max_step():
|
|
||||||
# stop when max iter is reached
|
|
||||||
break
|
break
|
||||||
self._timer.stop('train-epoch', keep_in_history=True)
|
|
||||||
self.call_hooks('after_train_epoch')
|
self._call_timer(action='stop', item='train-epoch', keep_in_history=True)
|
||||||
self._timer.reset('train-step')
|
self._call_hooks('after_train_epoch')
|
||||||
|
self._call_timer(action='reset', item='train-step')
|
||||||
|
|
||||||
def _eval(self,
|
def _eval(self,
|
||||||
|
test_dataloader: DataLoader,
|
||||||
epoch: int = None,
|
epoch: int = None,
|
||||||
return_loss: bool = True):
|
display_progress: bool = False):
|
||||||
# switch engine status
|
# switch engine status
|
||||||
self._engine.eval()
|
self._engine.eval()
|
||||||
|
|
||||||
self.call_hooks('before_test')
|
data_iter = iter(test_dataloader)
|
||||||
|
num_steps = len(test_dataloader)
|
||||||
|
|
||||||
|
self._call_hooks('before_test')
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# prepare progress bar
|
# prepare progress bar
|
||||||
progress = range(self._engine.schedule.num_steps)
|
progress = range(num_steps)
|
||||||
if self._display_progress:
|
if display_progress:
|
||||||
desc = 'Evaluation'
|
desc = 'Evaluation'
|
||||||
if epoch is not None:
|
if epoch is not None:
|
||||||
desc = '[Epoch %d val]' % epoch
|
desc = '[Epoch %d val]' % epoch
|
||||||
progress = tqdm(progress, desc=desc)
|
progress = tqdm(progress, desc=desc)
|
||||||
|
|
||||||
self.call_hooks('before_test_epoch')
|
self._call_hooks('before_test_epoch')
|
||||||
self._timer.start('test-epoch')
|
self._call_timer(action='start', item='test-epoch')
|
||||||
for _ in progress:
|
for _ in progress:
|
||||||
self.call_hooks('before_test_iter')
|
self._call_hooks('before_test_iter')
|
||||||
self._timer.start('test-step')
|
self._call_timer(action='start', item='test-step')
|
||||||
logits, label, loss = self._engine.step(
|
logits, label, loss = self._engine.step(data_iter, return_loss=True)
|
||||||
return_loss=return_loss)
|
self._call_timer(action='stop', item='test-step', keep_in_history=True)
|
||||||
self._timer.stop('test-step', keep_in_history=True)
|
self._call_hooks('after_test_iter',
|
||||||
self.call_hooks('after_test_iter',
|
output=(logits, label, loss))
|
||||||
output=(logits, label, loss))
|
self._call_timer(action='stop', item='test-epoch', keep_in_history=True)
|
||||||
self._timer.stop('test-epoch', keep_in_history=True)
|
self._call_hooks('after_test_epoch')
|
||||||
self.call_hooks('after_test_epoch')
|
self._call_hooks('after_test')
|
||||||
self.call_hooks('after_test')
|
self._call_timer(action='reset', item='test-step')
|
||||||
self._timer.reset('test-step')
|
self._call_timer(action='reset', item='test-epoch')
|
||||||
self._timer.reset('test-epoch')
|
|
||||||
|
def _exceed_max_step(self):
|
||||||
|
return self._max_steps is not None and self._cur_step > self._max_steps
|
||||||
|
|
||||||
def fit(self,
|
def fit(self,
|
||||||
train_dataloader: DataLoader,
|
train_dataloader: DataLoader,
|
||||||
test_dataloader: DataLoader = None,
|
epochs: int,
|
||||||
max_epochs: int = None,
|
|
||||||
max_steps: int = None,
|
max_steps: int = None,
|
||||||
|
test_dataloader: DataLoader = None,
|
||||||
test_interval: int = 1,
|
test_interval: int = 1,
|
||||||
display_progress: bool = False):
|
hooks_cfg: dict = None,
|
||||||
|
display_progress: bool = False,
|
||||||
|
):
|
||||||
"""Trains the model to fit training data.
|
"""Trains the model to fit training data.
|
||||||
|
|
||||||
:param train_dataloader: DataLoader in training
|
:param train_dataloader: DataLoader in training
|
||||||
:param test_dataloader: DataLoader in testing
|
:param epochs: Maximum number of epoches
|
||||||
:param max_epochs: Maximum number of epoches
|
|
||||||
:param max_steps: Maximum number of running iterations
|
:param max_steps: Maximum number of running iterations
|
||||||
|
:param test_dataloader: DataLoader in testing
|
||||||
:param test_interval: Interval of testing
|
:param test_interval: Interval of testing
|
||||||
|
:param hooks_cfg: A list of hook configuration
|
||||||
:param display_progress: If True, the training progress will be printed
|
:param display_progress: If True, the training progress will be printed
|
||||||
:type train_dataloader: DataLoader
|
:type train_dataloader: DataLoader
|
||||||
:type test_dataloader: DataLoader
|
:type epochs: int
|
||||||
:type max_epochs: int
|
|
||||||
:type max_steps: int
|
:type max_steps: int
|
||||||
|
:type test_dataloader: DataLoader
|
||||||
:type test_interval: int
|
:type test_interval: int
|
||||||
|
:type hooks_cfg: dict
|
||||||
:type display_progress: bool
|
:type display_progress: bool
|
||||||
|
:type gradient_accumulation: int
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# prepare dataloaders
|
# set epochs and steps, consider gradient accumulation
|
||||||
self._train_dataloader = train_dataloader
|
self._steps_per_epoch = len(train_dataloader) // self._engine.gradient_accumulation
|
||||||
self._engine.set_dataloader(self._train_dataloader, train=True)
|
self._max_steps = max_steps
|
||||||
self._engine.train()
|
self._max_epochs = epochs
|
||||||
|
|
||||||
|
# check if testing is required
|
||||||
should_test = False
|
should_test = False
|
||||||
if test_dataloader is not None:
|
if test_dataloader is not None:
|
||||||
self._test_dataloader = test_dataloader
|
|
||||||
self._engine.set_dataloader(self._test_dataloader, train=False)
|
|
||||||
should_test = True
|
should_test = True
|
||||||
|
|
||||||
# decide the
|
display_progress = self._should_display_progress(display_progress)
|
||||||
if max_epochs is not None:
|
|
||||||
self._max_epochs = max_epochs
|
# reset hooks
|
||||||
if max_steps is not None:
|
self._reset_states()
|
||||||
self._max_steps = max_steps
|
self.hooks = list()
|
||||||
self._set_display_progress(display_progress)
|
|
||||||
|
# build hooks
|
||||||
|
if hooks_cfg is not None:
|
||||||
|
for cfg in hooks_cfg:
|
||||||
|
hook = build_hooks(cfg, self)
|
||||||
|
self.hooks.append(hook)
|
||||||
|
self.hooks.sort(key=lambda hook: hook.priority)
|
||||||
|
if self._verbose:
|
||||||
|
for hook in self.hooks:
|
||||||
|
self._logger.info(
|
||||||
|
f'build {hook.__class__.__name__} for training, priority = {hook.priority}', ranks=[0])
|
||||||
|
self._logger.info("Lower value means higher priority for calling hook function")
|
||||||
|
|
||||||
# start train
|
# start train
|
||||||
self.call_hooks('before_train')
|
self._engine.train()
|
||||||
|
self._call_hooks('before_train')
|
||||||
|
|
||||||
# recover step value if resuming training
|
# recover step value if resuming training
|
||||||
if self.cur_epoch != 0:
|
|
||||||
self._recover_steps()
|
|
||||||
|
|
||||||
last_epoch = self._cur_epoch
|
last_epoch = self._cur_epoch
|
||||||
|
if self.cur_epoch != 0:
|
||||||
|
self._set_current_step(last_epoch)
|
||||||
|
|
||||||
for epoch in range(last_epoch, self._max_epochs):
|
for epoch in range(last_epoch, epochs):
|
||||||
self._cur_epoch += 1
|
|
||||||
|
|
||||||
# train for one epoch
|
# train for one epoch
|
||||||
self._train_epoch(epoch)
|
self._train_epoch(
|
||||||
|
train_dataloader=train_dataloader,
|
||||||
|
epoch=epoch,
|
||||||
|
display_progress=display_progress
|
||||||
|
)
|
||||||
|
|
||||||
# start eval
|
# start eval
|
||||||
if should_test and epoch % test_interval == 0:
|
if should_test and epoch % test_interval == 0:
|
||||||
self._eval(epoch, return_loss=True)
|
self._eval(test_dataloader=test_dataloader,
|
||||||
|
display_progress=display_progress,
|
||||||
|
epoch=epoch,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._cur_epoch += 1
|
||||||
|
|
||||||
# check for termination
|
# check for termination
|
||||||
if self.exceed_max_step():
|
if self._exceed_max_step():
|
||||||
self._logger.info(
|
self._logger.info(
|
||||||
f"Max number of steps {self._max_steps} has been reached, training is stopped automatically")
|
f"Max number of steps {max_steps} has been reached, training is stopped automatically")
|
||||||
break
|
break
|
||||||
self.call_hooks('after_train')
|
self._call_hooks('after_train')
|
||||||
self._timer.reset('train-epoch')
|
self._call_timer('reset', 'train-epoch')
|
||||||
|
|
||||||
def evaluate(self,
|
def evaluate(self,
|
||||||
test_dataloader: DataLoader,
|
test_dataloader: DataLoader,
|
||||||
|
@ -261,15 +324,13 @@ class Trainer:
|
||||||
:type test_dataloader: DataLoader
|
:type test_dataloader: DataLoader
|
||||||
:type display_progress: bool, optional
|
:type display_progress: bool, optional
|
||||||
"""
|
"""
|
||||||
# set dataloader
|
# set display
|
||||||
self._test_dataloader = test_dataloader
|
display_progress = self._should_display_progress(display_progress)
|
||||||
self._engine.set_dataloader(self._test_dataloader, train=True)
|
|
||||||
|
|
||||||
# set
|
|
||||||
self._set_display_progress(display_progress)
|
|
||||||
|
|
||||||
# eval
|
# eval
|
||||||
self._eval(return_loss=True)
|
self._eval(test_dataloader=test_dataloader,
|
||||||
|
display_progress=display_progress,
|
||||||
|
)
|
||||||
|
|
||||||
def predict(self, data: Union[Tensor, List[Tensor]]):
|
def predict(self, data: Union[Tensor, List[Tensor]]):
|
||||||
"""Uses trained model to make a prediction for a tensor or a tensor list.
|
"""Uses trained model to make a prediction for a tensor or a tensor list.
|
||||||
|
@ -289,45 +350,6 @@ class Trainer:
|
||||||
# prepare a list of (data, label) to make it iterable
|
# prepare a list of (data, label) to make it iterable
|
||||||
# for compatibility with schedule
|
# for compatibility with schedule
|
||||||
simple_dataloader = [(data, None)]
|
simple_dataloader = [(data, None)]
|
||||||
self._engine.set_dataloader(simple_dataloader)
|
data_iter = iter(simple_dataloader)
|
||||||
output, _, _ = self._engine.step(return_loss=False)
|
output, _, _ = self._engine.step(data_iter, return_loss=False)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def save(self, path: str, suffix: str = ''):
|
|
||||||
"""Saves the model to a file.
|
|
||||||
|
|
||||||
:param path: Relative path of the file
|
|
||||||
:param suffix: Suffix of the file
|
|
||||||
:type path: str
|
|
||||||
:type suffix: str, optional
|
|
||||||
"""
|
|
||||||
save_path = get_checkpoint_path(path,
|
|
||||||
self._cur_epoch,
|
|
||||||
suffix=suffix)
|
|
||||||
save_checkpoint(save_path, self._cur_epoch, self._engine.get_model(),
|
|
||||||
self._engine.get_optimizer(),
|
|
||||||
self._engine.get_lr_scheduler())
|
|
||||||
|
|
||||||
def load(self,
|
|
||||||
path: str,
|
|
||||||
finetune: bool = False,
|
|
||||||
strict: bool = False):
|
|
||||||
"""Loads parameters to the model from a file.
|
|
||||||
|
|
||||||
:param path: Relative path of the file
|
|
||||||
:param finetune: Whether allows to load a part of the model
|
|
||||||
:param strict: Whether loads a model that has the same shape of parameters
|
|
||||||
:type path: str
|
|
||||||
:type finetune: bool, optional
|
|
||||||
:type strict: bool, optional
|
|
||||||
"""
|
|
||||||
last_epoch, _ = load_checkpoint(path,
|
|
||||||
self._engine.get_model(),
|
|
||||||
self._engine.get_optimizer(),
|
|
||||||
self._engine.get_lr_scheduler(),
|
|
||||||
finetune=finetune,
|
|
||||||
strict=strict)
|
|
||||||
if finetune:
|
|
||||||
self.set_epoch(0)
|
|
||||||
else:
|
|
||||||
self.set_epoch(last_epoch)
|
|
||||||
|
|
|
@ -2,10 +2,12 @@ from ._base_hook import BaseHook
|
||||||
from ._checkpoint_hook import SaveCheckpointHook, LoadCheckpointHook
|
from ._checkpoint_hook import SaveCheckpointHook, LoadCheckpointHook
|
||||||
from ._metric_hook import LossHook, Accuracy2DHook, AccuracyHook, MetricHook
|
from ._metric_hook import LossHook, Accuracy2DHook, AccuracyHook, MetricHook
|
||||||
from ._log_hook import LogMetricByEpochHook, TensorboardHook, LogTimingByEpochHook, LogMemoryByEpochHook
|
from ._log_hook import LogMetricByEpochHook, TensorboardHook, LogTimingByEpochHook, LogMemoryByEpochHook
|
||||||
|
from ._lr_scheduler_hook import LRSchedulerHook
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'BaseHook', 'MetricHook',
|
'BaseHook', 'MetricHook',
|
||||||
'LoadCheckpointHook', 'SaveCheckpointHook',
|
'LoadCheckpointHook', 'SaveCheckpointHook',
|
||||||
'LossHook', 'AccuracyHook', 'Accuracy2DHook',
|
'LossHook', 'AccuracyHook', 'Accuracy2DHook',
|
||||||
'LogMetricByEpochHook', 'TensorboardHook', 'LogTimingByEpochHook', 'LogMemoryByEpochHook',
|
'LogMetricByEpochHook', 'TensorboardHook', 'LogTimingByEpochHook', 'LogMemoryByEpochHook',
|
||||||
|
'LRSchedulerHook'
|
||||||
]
|
]
|
||||||
|
|
|
@ -3,13 +3,13 @@
|
||||||
|
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
|
|
||||||
import torch.distributed as dist
|
|
||||||
|
|
||||||
from colossalai.checkpointing import get_latest_checkpoint_path, get_checkpoint_path
|
|
||||||
from colossalai.registry import HOOKS
|
from colossalai.registry import HOOKS
|
||||||
from colossalai.trainer.hooks import BaseHook
|
|
||||||
from colossalai.trainer import Trainer
|
from colossalai.trainer import Trainer
|
||||||
|
from colossalai.trainer.hooks import BaseHook
|
||||||
from colossalai.utils import is_dp_rank_0
|
from colossalai.utils import is_dp_rank_0
|
||||||
|
from colossalai.utils.checkpointing import get_latest_checkpoint_path, get_checkpoint_path
|
||||||
|
from colossalai.utils.checkpointing import save_checkpoint, load_checkpoint
|
||||||
|
from ._lr_scheduler_hook import LRSchedulerHook
|
||||||
|
|
||||||
|
|
||||||
@HOOKS.register_module
|
@HOOKS.register_module
|
||||||
|
@ -33,7 +33,7 @@ class SaveCheckpointHook(BaseHook):
|
||||||
interval: int = 1,
|
interval: int = 1,
|
||||||
checkpoint_dir: str = None,
|
checkpoint_dir: str = None,
|
||||||
suffix: str = '',
|
suffix: str = '',
|
||||||
priority: int = 0):
|
priority: int = 10):
|
||||||
super().__init__(trainer=trainer, priority=priority)
|
super().__init__(trainer=trainer, priority=priority)
|
||||||
assert isinstance(trainer, Trainer), \
|
assert isinstance(trainer, Trainer), \
|
||||||
f'SaveCheckpointHook expects a Trainer, got {type(trainer)}'
|
f'SaveCheckpointHook expects a Trainer, got {type(trainer)}'
|
||||||
|
@ -41,6 +41,16 @@ class SaveCheckpointHook(BaseHook):
|
||||||
self.checkpoint_dir = checkpoint_dir
|
self.checkpoint_dir = checkpoint_dir
|
||||||
self.suffix = suffix
|
self.suffix = suffix
|
||||||
|
|
||||||
|
# get lr scheduler from the LRSchedulerHook before train
|
||||||
|
self._lr_scheduler = None
|
||||||
|
|
||||||
|
def before_train(self):
|
||||||
|
# check if lr scheduler is present in LRSchedulerHook
|
||||||
|
for hook in self.trainer.hooks:
|
||||||
|
if isinstance(hook, LRSchedulerHook):
|
||||||
|
self._lr_scheduler = hook.lr_scheduler
|
||||||
|
break
|
||||||
|
|
||||||
def after_train_epoch(self):
|
def after_train_epoch(self):
|
||||||
"""Saves the model after a training epoch.
|
"""Saves the model after a training epoch.
|
||||||
"""
|
"""
|
||||||
|
@ -48,14 +58,18 @@ class SaveCheckpointHook(BaseHook):
|
||||||
if self.trainer.cur_epoch % self.interval == 0:
|
if self.trainer.cur_epoch % self.interval == 0:
|
||||||
# only gpus with data parallel rank equals to 0 write to the disk
|
# only gpus with data parallel rank equals to 0 write to the disk
|
||||||
if is_dp_rank_0():
|
if is_dp_rank_0():
|
||||||
self.trainer.save(path=self.checkpoint_dir, suffix=self.suffix)
|
save_path = get_checkpoint_path(self.checkpoint_dir,
|
||||||
|
self.trainer.cur_epoch,
|
||||||
|
suffix=self.suffix)
|
||||||
|
|
||||||
|
save_checkpoint(save_path,
|
||||||
|
self.trainer.cur_epoch,
|
||||||
|
self.trainer.engine.model,
|
||||||
|
self.trainer.engine.optimizer,
|
||||||
|
self._lr_scheduler)
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
f'checkpoint for epoch {self.trainer.cur_epoch} is saved to {self.checkpoint_dir}')
|
f'checkpoint for epoch {self.trainer.cur_epoch} is saved to {self.checkpoint_dir}')
|
||||||
|
|
||||||
# wait until everyone is done
|
|
||||||
if dist.is_initialized():
|
|
||||||
dist.barrier()
|
|
||||||
|
|
||||||
|
|
||||||
@HOOKS.register_module
|
@HOOKS.register_module
|
||||||
class LoadCheckpointHook(BaseHook):
|
class LoadCheckpointHook(BaseHook):
|
||||||
|
@ -81,30 +95,46 @@ class LoadCheckpointHook(BaseHook):
|
||||||
epoch: int = -1,
|
epoch: int = -1,
|
||||||
finetune: bool = False,
|
finetune: bool = False,
|
||||||
strict: bool = False,
|
strict: bool = False,
|
||||||
priority: int = 10) -> None:
|
suffix: str = '',
|
||||||
|
priority: int = 0) -> None:
|
||||||
|
super().__init__(trainer=trainer, priority=priority)
|
||||||
assert isinstance(trainer, Trainer), \
|
assert isinstance(trainer, Trainer), \
|
||||||
f'LoadLatestCheckpointHook excepts a Trainer, got {type(trainer)}'
|
f'LoadLatestCheckpointHook excepts a Trainer, got {type(trainer)}'
|
||||||
self.epoch = epoch
|
self.epoch = epoch
|
||||||
self.checkpoint_dir = checkpoint_dir
|
self.checkpoint_dir = checkpoint_dir
|
||||||
self.finetune = finetune
|
self.finetune = finetune
|
||||||
|
self.suffix = suffix
|
||||||
self.strict = strict
|
self.strict = strict
|
||||||
super().__init__(trainer=trainer, priority=priority)
|
|
||||||
|
|
||||||
def before_train(self):
|
def before_train(self):
|
||||||
"""Loads parameters to the model before training.
|
"""Loads parameters to the model before training.
|
||||||
"""
|
"""
|
||||||
|
# check if lr scheduler is present in LRSchedulerHook
|
||||||
|
lr_scheduler = None
|
||||||
|
for hook in self.trainer.hooks:
|
||||||
|
if isinstance(hook, LRSchedulerHook):
|
||||||
|
lr_scheduler = hook.lr_scheduler
|
||||||
|
break
|
||||||
|
|
||||||
|
# use latest checkpoint if epoch = -1
|
||||||
if self.epoch == -1:
|
if self.epoch == -1:
|
||||||
path = get_latest_checkpoint_path(self.checkpoint_dir)
|
path = get_latest_checkpoint_path(self.checkpoint_dir, suffix=self.suffix)
|
||||||
else:
|
else:
|
||||||
path = get_checkpoint_path(self.checkpoint_dir, epoch=self.epoch)
|
path = get_checkpoint_path(self.checkpoint_dir, epoch=self.epoch, suffix=self.suffix)
|
||||||
|
|
||||||
if osp.exists(path):
|
if osp.exists(path):
|
||||||
self.trainer.load(
|
last_epoch, _ = load_checkpoint(path,
|
||||||
path, finetune=self.finetune, strict=self.strict)
|
self.trainer.engine.model,
|
||||||
|
self.trainer.engine.optimizer,
|
||||||
|
lr_scheduler,
|
||||||
|
finetune=self.finetune,
|
||||||
|
strict=self.strict)
|
||||||
|
if self.finetune:
|
||||||
|
self.trainer.cur_epoch = 0
|
||||||
|
else:
|
||||||
|
self.trainer.cur_epoch = last_epoch
|
||||||
|
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
f'loaded checkpoint from {path}')
|
f'loaded checkpoint from {path}')
|
||||||
else:
|
else:
|
||||||
raise FileNotFoundError(f'checkpoint is not found at {path}')
|
raise FileNotFoundError(f'checkpoint is not found at {path}')
|
||||||
|
|
||||||
# Some utilities want to load a checkpoint without distributed being initialized
|
|
||||||
if dist.is_initialized():
|
|
||||||
dist.barrier()
|
|
||||||
|
|
|
@ -5,7 +5,7 @@ import os
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from tensorboardX import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from colossalai.context import ParallelMode
|
from colossalai.context import ParallelMode
|
||||||
from colossalai.core import global_context as gpc
|
from colossalai.core import global_context as gpc
|
||||||
|
@ -13,7 +13,7 @@ from colossalai.registry import HOOKS
|
||||||
from colossalai.trainer._trainer import Trainer
|
from colossalai.trainer._trainer import Trainer
|
||||||
from colossalai.utils import get_global_multitimer, set_global_multitimer_status, report_memory_usage, is_dp_rank_0, \
|
from colossalai.utils import get_global_multitimer, set_global_multitimer_status, report_memory_usage, is_dp_rank_0, \
|
||||||
is_tp_rank_0, is_no_pp_or_last_stage
|
is_tp_rank_0, is_no_pp_or_last_stage
|
||||||
from ._metric_hook import MetricHook
|
from ._base_hook import BaseHook
|
||||||
|
|
||||||
|
|
||||||
def _format_number(val):
|
def _format_number(val):
|
||||||
|
@ -24,7 +24,7 @@ def _format_number(val):
|
||||||
return val
|
return val
|
||||||
|
|
||||||
|
|
||||||
class EpochIntervalHook(MetricHook):
|
class EpochIntervalHook(BaseHook):
|
||||||
def __init__(self, trainer: Trainer, interval: int = 1, priority: int = 1):
|
def __init__(self, trainer: Trainer, interval: int = 1, priority: int = 1):
|
||||||
super().__init__(trainer, priority)
|
super().__init__(trainer, priority)
|
||||||
self._interval = interval
|
self._interval = interval
|
||||||
|
@ -45,7 +45,7 @@ class LogMetricByEpochHook(EpochIntervalHook):
|
||||||
:type priority: int, optional
|
:type priority: int, optional
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, trainer: Trainer, interval: int = 1, priority: int = 1) -> None:
|
def __init__(self, trainer: Trainer, interval: int = 1, priority: int = 10) -> None:
|
||||||
super().__init__(trainer=trainer, interval=interval, priority=priority)
|
super().__init__(trainer=trainer, interval=interval, priority=priority)
|
||||||
self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage()
|
self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage()
|
||||||
|
|
||||||
|
@ -74,7 +74,7 @@ class LogMetricByEpochHook(EpochIntervalHook):
|
||||||
|
|
||||||
|
|
||||||
@HOOKS.register_module
|
@HOOKS.register_module
|
||||||
class TensorboardHook(MetricHook):
|
class TensorboardHook(BaseHook):
|
||||||
"""Specialized Hook to record the metric to Tensorboard.
|
"""Specialized Hook to record the metric to Tensorboard.
|
||||||
|
|
||||||
:param trainer: Trainer attached with current hook
|
:param trainer: Trainer attached with current hook
|
||||||
|
@ -85,59 +85,71 @@ class TensorboardHook(MetricHook):
|
||||||
:type priority: int, optional
|
:type priority: int, optional
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, trainer: Trainer, log_dir: str, priority: int = 1) -> None:
|
def __init__(self,
|
||||||
|
trainer: Trainer,
|
||||||
|
log_dir: str,
|
||||||
|
dp_rank_0_only: bool = True,
|
||||||
|
tp_rank_0_only: bool = True,
|
||||||
|
priority: int = 10,
|
||||||
|
) -> None:
|
||||||
super().__init__(trainer=trainer, priority=priority)
|
super().__init__(trainer=trainer, priority=priority)
|
||||||
self._is_rank_to_log = is_no_pp_or_last_stage()
|
|
||||||
|
|
||||||
if self._is_rank_to_log:
|
# create log dir
|
||||||
|
if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0:
|
||||||
|
os.makedirs(log_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# determine the ranks to generate tensorboard logs
|
||||||
|
self._is_valid_rank_to_log = is_no_pp_or_last_stage()
|
||||||
|
|
||||||
|
if dp_rank_0_only:
|
||||||
|
self._is_valid_rank_to_log = self._is_valid_rank_to_log and is_dp_rank_0()
|
||||||
|
|
||||||
|
if tp_rank_0_only:
|
||||||
|
self._is_valid_rank_to_log = self._is_valid_rank_to_log and is_tp_rank_0()
|
||||||
|
|
||||||
|
if self._is_valid_rank_to_log:
|
||||||
# create workspace on only one rank
|
# create workspace on only one rank
|
||||||
if gpc.is_initialized(ParallelMode.GLOBAL):
|
if gpc.is_initialized(ParallelMode.GLOBAL):
|
||||||
rank = gpc.get_global_rank()
|
rank = gpc.get_global_rank()
|
||||||
else:
|
else:
|
||||||
rank = 0
|
rank = 0
|
||||||
|
|
||||||
log_dir = osp.join(log_dir, f'rank_{rank}')
|
|
||||||
|
|
||||||
# create workspace
|
# create workspace
|
||||||
if not osp.exists(log_dir):
|
log_dir = osp.join(log_dir, f'rank_{rank}')
|
||||||
os.makedirs(log_dir)
|
os.makedirs(log_dir, exist_ok=True)
|
||||||
|
|
||||||
self.writer = SummaryWriter(
|
self.writer = SummaryWriter(
|
||||||
log_dir=log_dir, filename_suffix=f'_rank_{rank}')
|
log_dir=log_dir, filename_suffix=f'_rank_{rank}')
|
||||||
|
|
||||||
def after_train_iter(self, *args):
|
def _log_by_iter(self, mode: str):
|
||||||
for metric_name, metric_calculator in self.trainer.states['metrics']['train'].items():
|
for metric_name, metric_calculator in self.trainer.states['metrics'][mode].items():
|
||||||
if metric_calculator.epoch_only:
|
if metric_calculator.epoch_only:
|
||||||
continue
|
continue
|
||||||
val = metric_calculator.get_last_step_value()
|
val = metric_calculator.get_last_step_value()
|
||||||
if self._is_rank_to_log:
|
|
||||||
self.writer.add_scalar(
|
|
||||||
f'{metric_name}/train', val, self.trainer.cur_step)
|
|
||||||
|
|
||||||
def after_test_iter(self, *args):
|
if self._is_valid_rank_to_log:
|
||||||
for metric_name, metric_calculator in self.trainer.states['metrics']['test'].items():
|
self.writer.add_scalar(f'{metric_name}/{mode}', val,
|
||||||
if metric_calculator.epoch_only:
|
|
||||||
continue
|
|
||||||
val = metric_calculator.get_last_step_value()
|
|
||||||
if self._is_rank_to_log:
|
|
||||||
self.writer.add_scalar(f'{metric_name}/test', val,
|
|
||||||
self.trainer.cur_step)
|
self.trainer.cur_step)
|
||||||
|
|
||||||
def after_test_epoch(self):
|
def _log_by_epoch(self, mode: str):
|
||||||
for metric_name, metric_calculator in self.trainer.states['metrics']['test'].items():
|
for metric_name, metric_calculator in self.trainer.states['metrics'][mode].items():
|
||||||
if metric_calculator.epoch_only:
|
if metric_calculator.epoch_only:
|
||||||
val = metric_calculator.get_accumulated_value()
|
val = metric_calculator.get_accumulated_value()
|
||||||
if self._is_rank_to_log:
|
if self._is_valid_rank_to_log:
|
||||||
self.writer.add_scalar(f'{metric_name}/test', val,
|
self.writer.add_scalar(f'{metric_name}/{mode}', val,
|
||||||
self.trainer.cur_step)
|
self.trainer.cur_step)
|
||||||
|
|
||||||
|
def after_test_iter(self, *args):
|
||||||
|
self._log_by_iter(mode='test')
|
||||||
|
|
||||||
|
def after_test_epoch(self):
|
||||||
|
self._log_by_epoch(mode='test')
|
||||||
|
|
||||||
|
def after_train_iter(self, *args):
|
||||||
|
self._log_by_iter(mode='train')
|
||||||
|
|
||||||
def after_train_epoch(self):
|
def after_train_epoch(self):
|
||||||
for metric_name, metric_calculator in self.trainer.states['metrics']['train'].items():
|
self._log_by_epoch(mode='train')
|
||||||
if metric_calculator.epoch_only:
|
|
||||||
val = metric_calculator.get_accumulated_value()
|
|
||||||
if self._is_rank_to_log:
|
|
||||||
self.writer.add_scalar(f'{metric_name}/train', val,
|
|
||||||
self.trainer.cur_step)
|
|
||||||
|
|
||||||
|
|
||||||
@HOOKS.register_module
|
@HOOKS.register_module
|
||||||
|
@ -157,7 +169,7 @@ class LogTimingByEpochHook(EpochIntervalHook):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
trainer: Trainer,
|
trainer: Trainer,
|
||||||
interval: int = 1,
|
interval: int = 1,
|
||||||
priority: int = 1,
|
priority: int = 10,
|
||||||
log_eval: bool = True
|
log_eval: bool = True
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(trainer=trainer, interval=interval, priority=priority)
|
super().__init__(trainer=trainer, interval=interval, priority=priority)
|
||||||
|
@ -217,7 +229,7 @@ class LogMemoryByEpochHook(EpochIntervalHook):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
trainer: Trainer,
|
trainer: Trainer,
|
||||||
interval: int = 1,
|
interval: int = 1,
|
||||||
priority: int = 1,
|
priority: int = 10,
|
||||||
log_eval: bool = True
|
log_eval: bool = True
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(trainer=trainer, interval=interval, priority=priority)
|
super().__init__(trainer=trainer, interval=interval, priority=priority)
|
||||||
|
|
|
@ -0,0 +1,58 @@
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from colossalai.builder import build_lr_scheduler
|
||||||
|
from colossalai.registry import HOOKS
|
||||||
|
from ._metric_hook import MetricHook
|
||||||
|
from .._trainer import Trainer
|
||||||
|
from ..metric import LearningRate
|
||||||
|
|
||||||
|
|
||||||
|
@HOOKS.register_module
|
||||||
|
class LRSchedulerHook(MetricHook):
|
||||||
|
"""Build LR scheduler
|
||||||
|
|
||||||
|
:param trainer: Trainer attached with current hook
|
||||||
|
:type trainer: Trainer
|
||||||
|
:param lr_scheduler_cfg: The config of LR scheduler
|
||||||
|
:type lr_scheduler_cfg: dict
|
||||||
|
:param by_epoch: If `True`, the LR will be scheduled every epoch. Else, the LR will be scheduled every batch. Defaults to `True`.
|
||||||
|
:type by_epoch: bool
|
||||||
|
:param priority: Priority in the printing, hooks with small priority will be printed in front
|
||||||
|
:type priority: int, optional
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
trainer: Trainer,
|
||||||
|
lr_scheduler_cfg: dict,
|
||||||
|
by_epoch: bool = True,
|
||||||
|
store_lr_in_state: bool = True,
|
||||||
|
priority: int = 1,
|
||||||
|
):
|
||||||
|
super().__init__(trainer=trainer, priority=priority)
|
||||||
|
self.by_epoch = by_epoch
|
||||||
|
|
||||||
|
if by_epoch:
|
||||||
|
total_steps = trainer.max_epochs
|
||||||
|
else:
|
||||||
|
total_steps = trainer.max_epochs * trainer.steps_per_epoch
|
||||||
|
if trainer.max_steps is not None:
|
||||||
|
total_steps = min(total_steps, trainer.max_steps)
|
||||||
|
|
||||||
|
lr_scheduler_cfg['total_steps'] = total_steps
|
||||||
|
|
||||||
|
self.lr_scheduler = build_lr_scheduler(
|
||||||
|
lr_scheduler_cfg, trainer.engine.optimizer)
|
||||||
|
|
||||||
|
if store_lr_in_state:
|
||||||
|
self.trainer.states['metrics']['train']['lr'] = LearningRate(epoch_only=by_epoch,
|
||||||
|
initial_lr=self.lr_scheduler.get_lr()[0])
|
||||||
|
|
||||||
|
def after_train_epoch(self):
|
||||||
|
if self.by_epoch:
|
||||||
|
self.lr_scheduler.step()
|
||||||
|
self.trainer.states['metrics']['train']['lr'].update(self.lr_scheduler.get_lr()[0])
|
||||||
|
|
||||||
|
def after_train_iter(self, output: Tensor, label: Tensor, loss: Tensor):
|
||||||
|
if not self.by_epoch:
|
||||||
|
self.lr_scheduler.step()
|
||||||
|
self.trainer.states['metrics']['train']['lr'].update(self.lr_scheduler.get_lr()[0])
|
|
@ -21,9 +21,12 @@ class MetricHook(BaseHook):
|
||||||
:type priority: int
|
:type priority: int
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, trainer: Trainer, priority: int):
|
def __init__(self,
|
||||||
|
trainer: Trainer,
|
||||||
|
priority: int,
|
||||||
|
):
|
||||||
super().__init__(trainer, priority)
|
super().__init__(trainer, priority)
|
||||||
self._is_stage_to_log = is_no_pp_or_last_stage()
|
self._is_stage_to_compute = is_no_pp_or_last_stage()
|
||||||
self._check_metric_states_initialization()
|
self._check_metric_states_initialization()
|
||||||
|
|
||||||
def _check_metric_states_initialization(self):
|
def _check_metric_states_initialization(self):
|
||||||
|
@ -41,33 +44,34 @@ class LossHook(MetricHook):
|
||||||
:type priority: int, optional
|
:type priority: int, optional
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, trainer: Trainer, priority: int = 10):
|
def __init__(self, trainer: Trainer, priority: int = 0):
|
||||||
super().__init__(trainer, priority)
|
super().__init__(trainer, priority)
|
||||||
|
|
||||||
if self._is_stage_to_log:
|
if self._is_stage_to_compute:
|
||||||
self.metric = Loss(epoch_only=False)
|
self.train_loss = Loss(epoch_only=False)
|
||||||
|
self.test_loss = Loss(epoch_only=True)
|
||||||
|
|
||||||
# register the metric calculator
|
# register the metric calculator
|
||||||
self.trainer.states['metrics']['train'][
|
self.trainer.states['metrics']['train'][
|
||||||
self.metric.__class__.__name__] = self.metric
|
self.train_loss.__class__.__name__] = self.train_loss
|
||||||
self.trainer.states['metrics']['test'][
|
self.trainer.states['metrics']['test'][
|
||||||
self.metric.__class__.__name__] = self.metric
|
self.test_loss.__class__.__name__] = self.test_loss
|
||||||
|
|
||||||
def before_train_epoch(self):
|
def before_train_epoch(self):
|
||||||
if self._is_stage_to_log:
|
if self._is_stage_to_compute:
|
||||||
self.metric.reset()
|
self.train_loss.reset()
|
||||||
|
|
||||||
def after_train_iter(self, logits, label, loss):
|
def after_train_iter(self, logits, label, loss):
|
||||||
if self._is_stage_to_log:
|
if self._is_stage_to_compute:
|
||||||
self.metric.update(loss)
|
self.train_loss.update(loss)
|
||||||
|
|
||||||
def before_test_epoch(self):
|
def before_test_epoch(self):
|
||||||
if self._is_stage_to_log:
|
if self._is_stage_to_compute:
|
||||||
self.metric.reset()
|
self.test_loss.reset()
|
||||||
|
|
||||||
def after_test_iter(self, logits, label, loss):
|
def after_test_iter(self, logits, label, loss):
|
||||||
if self._is_stage_to_log:
|
if self._is_stage_to_compute:
|
||||||
self.metric.update(loss)
|
self.test_loss.update(loss)
|
||||||
|
|
||||||
|
|
||||||
@HOOKS.register_module
|
@HOOKS.register_module
|
||||||
|
@ -81,10 +85,10 @@ class Accuracy2DHook(MetricHook):
|
||||||
:type priority: int, optional
|
:type priority: int, optional
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, trainer: Trainer, priority: int = 10):
|
def __init__(self, trainer: Trainer, priority: int = 0):
|
||||||
super().__init__(trainer, priority)
|
super().__init__(trainer, priority)
|
||||||
|
|
||||||
if self._is_stage_to_log:
|
if self._is_stage_to_compute:
|
||||||
self.metric = Accuracy2D(epoch_only=True)
|
self.metric = Accuracy2D(epoch_only=True)
|
||||||
|
|
||||||
# register the metric
|
# register the metric
|
||||||
|
@ -92,20 +96,20 @@ class Accuracy2DHook(MetricHook):
|
||||||
self.metric.__class__.__name__] = self.metric
|
self.metric.__class__.__name__] = self.metric
|
||||||
|
|
||||||
def before_test(self):
|
def before_test(self):
|
||||||
if self._is_stage_to_log:
|
if self._is_stage_to_compute:
|
||||||
self.metric.reset()
|
self.metric.reset()
|
||||||
|
|
||||||
def after_test_iter(self, logits, label, *args):
|
def after_test_iter(self, logits, label, *args):
|
||||||
if self._is_stage_to_log:
|
if self._is_stage_to_compute:
|
||||||
self.metric.update(logits, label)
|
self.metric.update(logits, label)
|
||||||
|
|
||||||
|
|
||||||
@HOOKS.register_module
|
@HOOKS.register_module
|
||||||
class Accuracy2p5DHook(MetricHook):
|
class Accuracy2p5DHook(MetricHook):
|
||||||
def __init__(self, trainer: Trainer, priority: int = 10):
|
def __init__(self, trainer: Trainer, priority: int = 0):
|
||||||
super().__init__(trainer, priority)
|
super().__init__(trainer, priority)
|
||||||
|
|
||||||
if self._is_stage_to_log:
|
if self._is_stage_to_compute:
|
||||||
self.metric = Accuracy2p5D(epoch_only=True)
|
self.metric = Accuracy2p5D(epoch_only=True)
|
||||||
|
|
||||||
# register the metric
|
# register the metric
|
||||||
|
@ -113,11 +117,11 @@ class Accuracy2p5DHook(MetricHook):
|
||||||
self.metric.__class__.__name__] = self.metric
|
self.metric.__class__.__name__] = self.metric
|
||||||
|
|
||||||
def before_test(self):
|
def before_test(self):
|
||||||
if self._is_stage_to_log:
|
if self._is_stage_to_compute:
|
||||||
self.metric.reset()
|
self.metric.reset()
|
||||||
|
|
||||||
def after_test_iter(self, logits, label, *args):
|
def after_test_iter(self, logits, label, *args):
|
||||||
if self._is_stage_to_log:
|
if self._is_stage_to_compute:
|
||||||
self.metric.update(logits, label)
|
self.metric.update(logits, label)
|
||||||
|
|
||||||
|
|
||||||
|
@ -138,7 +142,7 @@ class Accuracy3DHook(MetricHook):
|
||||||
priority: int = 10):
|
priority: int = 10):
|
||||||
super().__init__(trainer, priority)
|
super().__init__(trainer, priority)
|
||||||
|
|
||||||
if self._is_stage_to_log:
|
if self._is_stage_to_compute:
|
||||||
self.metric = Accuracy3D(epoch_only=True,
|
self.metric = Accuracy3D(epoch_only=True,
|
||||||
input_parallel_mode=input_parallel_mode,
|
input_parallel_mode=input_parallel_mode,
|
||||||
weight_parallel_mode=weight_parallel_mode)
|
weight_parallel_mode=weight_parallel_mode)
|
||||||
|
@ -148,11 +152,11 @@ class Accuracy3DHook(MetricHook):
|
||||||
self.metric.__class__.__name__] = self.metric
|
self.metric.__class__.__name__] = self.metric
|
||||||
|
|
||||||
def before_test(self):
|
def before_test(self):
|
||||||
if self._is_stage_to_log:
|
if self._is_stage_to_compute:
|
||||||
self.metric.reset()
|
self.metric.reset()
|
||||||
|
|
||||||
def after_test_iter(self, logits, label, *args):
|
def after_test_iter(self, logits, label, *args):
|
||||||
if self._is_stage_to_log:
|
if self._is_stage_to_compute:
|
||||||
self.metric.update(logits, label)
|
self.metric.update(logits, label)
|
||||||
|
|
||||||
|
|
||||||
|
@ -166,10 +170,10 @@ class AccuracyHook(MetricHook):
|
||||||
:type priority: int
|
:type priority: int
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, trainer: Trainer, priority: int = 10):
|
def __init__(self, trainer: Trainer, priority: int = 0):
|
||||||
super().__init__(trainer, priority)
|
super().__init__(trainer, priority)
|
||||||
|
|
||||||
if self._is_stage_to_log:
|
if self._is_stage_to_compute:
|
||||||
self.metric = Accuracy(epoch_only=True)
|
self.metric = Accuracy(epoch_only=True)
|
||||||
|
|
||||||
# register the metric
|
# register the metric
|
||||||
|
@ -177,9 +181,9 @@ class AccuracyHook(MetricHook):
|
||||||
self.metric.__class__.__name__] = self.metric
|
self.metric.__class__.__name__] = self.metric
|
||||||
|
|
||||||
def before_test(self):
|
def before_test(self):
|
||||||
if self._is_stage_to_log:
|
if self._is_stage_to_compute:
|
||||||
self.metric.reset()
|
self.metric.reset()
|
||||||
|
|
||||||
def after_test_iter(self, logits, label, *args):
|
def after_test_iter(self, logits, label, *args):
|
||||||
if self._is_stage_to_log:
|
if self._is_stage_to_compute:
|
||||||
self.metric.update(logits, label)
|
self.metric.update(logits, label)
|
||||||
|
|
|
@ -126,6 +126,33 @@ class Loss(Metric):
|
||||||
return a < b
|
return a < b
|
||||||
|
|
||||||
|
|
||||||
|
class LearningRate(Metric):
|
||||||
|
"""A metric collector for learning rate.
|
||||||
|
|
||||||
|
:param epoch_only: Whether the metric only read for the full epoch
|
||||||
|
:type epoch_only: bool
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, epoch_only: bool, initial_lr: float = 0.):
|
||||||
|
super().__init__(epoch_only=epoch_only)
|
||||||
|
self.lr = 0.
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def update(self, lr) -> None:
|
||||||
|
self.lr = lr
|
||||||
|
|
||||||
|
def get_last_step_value(self):
|
||||||
|
return self.lr
|
||||||
|
|
||||||
|
def get_accumulated_value(self):
|
||||||
|
return self.lr
|
||||||
|
|
||||||
|
def is_better(a, b) -> bool:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class Accuracy(Metric):
|
class Accuracy(Metric):
|
||||||
"""A metric collector for accuracy. It only works for classification
|
"""A metric collector for accuracy. It only works for classification
|
||||||
tasks.
|
tasks.
|
||||||
|
|
|
@ -5,9 +5,9 @@ from typing import Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from .context import Config
|
from colossalai.context import Config
|
||||||
from .context.parallel_mode import ParallelMode
|
from colossalai.context.parallel_mode import ParallelMode
|
||||||
from .core import global_context as gpc
|
from colossalai.core import global_context as gpc
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'get_checkpoint_path',
|
'get_checkpoint_path',
|
|
@ -27,7 +27,7 @@ def sync_model_param_in_dp(model):
|
||||||
:param model: A pyTorch nn.model on whose parameters you check the consistency
|
:param model: A pyTorch nn.model on whose parameters you check the consistency
|
||||||
'''
|
'''
|
||||||
|
|
||||||
if gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 2:
|
if gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1:
|
||||||
for param in model.parameters():
|
for param in model.parameters():
|
||||||
ranks = gpc.get_ranks_in_group(ParallelMode.DATA)
|
ranks = gpc.get_ranks_in_group(ParallelMode.DATA)
|
||||||
dist.broadcast(param, src=ranks[0], group=gpc.get_group(ParallelMode.DATA))
|
dist.broadcast(param, src=ranks[0], group=gpc.get_group(ParallelMode.DATA))
|
||||||
|
|
|
@ -4,6 +4,7 @@ import os
|
||||||
|
|
||||||
IMG_SIZE = 224
|
IMG_SIZE = 224
|
||||||
BATCH_SIZE = 256
|
BATCH_SIZE = 256
|
||||||
|
NUM_EPOCHS = 100
|
||||||
|
|
||||||
model = dict(
|
model = dict(
|
||||||
type='VanillaResNet',
|
type='VanillaResNet',
|
||||||
|
@ -67,8 +68,6 @@ loss = dict(
|
||||||
type='CrossEntropyLoss'
|
type='CrossEntropyLoss'
|
||||||
)
|
)
|
||||||
|
|
||||||
max_epochs = 100
|
|
||||||
|
|
||||||
from colossalai.engine import AMP_TYPE
|
from colossalai.engine import AMP_TYPE
|
||||||
|
|
||||||
fp16 = dict(
|
fp16 = dict(
|
||||||
|
|
|
@ -1,21 +1,20 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
# -*- encoding: utf-8 -*-
|
# -*- encoding: utf-8 -*-
|
||||||
|
|
||||||
|
NUM_EPOCH = int
|
||||||
|
|
||||||
model = dict()
|
model = dict()
|
||||||
train_data = dict()
|
train_data = dict()
|
||||||
test_data = dict()
|
test_data = dict()
|
||||||
optimizer = dict()
|
optimizer = dict()
|
||||||
loss = dict()
|
loss = dict()
|
||||||
lr_scheduler = dict()
|
|
||||||
|
|
||||||
fp16 = dict()
|
fp16 = dict()
|
||||||
zero = dict()
|
zero = dict()
|
||||||
|
|
||||||
gradient_handler = []
|
gradient_handler = []
|
||||||
parallel = dict()
|
parallel = dict()
|
||||||
|
hooks = []
|
||||||
num_epochs = int
|
|
||||||
num_steps = int
|
|
||||||
|
|
||||||
cudnn_benchmark = True
|
cudnn_benchmark = True
|
||||||
cudnn_deterministic = False
|
cudnn_deterministic = False
|
||||||
|
|
|
@ -8,10 +8,11 @@ BATCH_SIZE = 512
|
||||||
IMG_SIZE = 32
|
IMG_SIZE = 32
|
||||||
PATCH_SIZE = 4
|
PATCH_SIZE = 4
|
||||||
DIM = 512
|
DIM = 512
|
||||||
NUM_ATTENTION_HEADS = 8
|
NUM_ATTENTION_HEADS = 2
|
||||||
SUMMA_DIM = 2
|
SUMMA_DIM = 2
|
||||||
NUM_CLASSES = 10
|
NUM_CLASSES = 10
|
||||||
DEPTH = 6
|
DEPTH = 1
|
||||||
|
NUM_EPOCHS = 60
|
||||||
|
|
||||||
train_data = dict(
|
train_data = dict(
|
||||||
dataset=dict(
|
dataset=dict(
|
||||||
|
@ -127,14 +128,22 @@ hooks = [
|
||||||
dict(type='LogMetricByEpochHook'),
|
dict(type='LogMetricByEpochHook'),
|
||||||
dict(type='Accuracy2DHook'),
|
dict(type='Accuracy2DHook'),
|
||||||
dict(type='LossHook'),
|
dict(type='LossHook'),
|
||||||
dict(type='TensorboardHook', log_dir='./tfb_logs'),
|
dict(
|
||||||
|
type='LRSchedulerHook',
|
||||||
|
by_epoch=True,
|
||||||
|
lr_scheduler_cfg=dict(
|
||||||
|
type='LinearWarmupLR',
|
||||||
|
warmup_steps=5
|
||||||
|
)
|
||||||
|
),
|
||||||
|
dict(type='TensorboardHook', log_dir='./tb_logs'),
|
||||||
# dict(type='SaveCheckpointHook', interval=5, checkpoint_dir='./ckpt'),
|
# dict(type='SaveCheckpointHook', interval=5, checkpoint_dir='./ckpt'),
|
||||||
# dict(type='LoadCheckpointHook', epoch=20, checkpoint_dir='./ckpt')
|
# dict(type='LoadCheckpointHook', epoch=20, checkpoint_dir='./ckpt')
|
||||||
]
|
]
|
||||||
|
|
||||||
parallel = dict(
|
parallel = dict(
|
||||||
pipeline=dict(size=1),
|
pipeline=dict(size=1),
|
||||||
tensor=dict(size=4, mode='2d'),
|
tensor=dict(size=1, mode='2d'),
|
||||||
)
|
)
|
||||||
|
|
||||||
# for fp16 training
|
# for fp16 training
|
||||||
|
@ -144,17 +153,11 @@ parallel = dict(
|
||||||
# initial_scale=2 ** 8
|
# initial_scale=2 ** 8
|
||||||
# )
|
# )
|
||||||
|
|
||||||
lr_scheduler = dict(
|
|
||||||
type='LinearWarmupLR',
|
|
||||||
warmup_epochs=5
|
|
||||||
)
|
|
||||||
|
|
||||||
# only needed when pipeline parallel is used
|
# only needed when pipeline parallel is used
|
||||||
# schedule = dict(
|
# schedule = dict(
|
||||||
# num_microbatches=8
|
# num_microbatches=8
|
||||||
# )
|
# )
|
||||||
|
|
||||||
num_epochs = 60
|
|
||||||
|
|
||||||
logging = dict(
|
logging = dict(
|
||||||
root_path='./logs'
|
root_path='./logs'
|
||||||
|
|
|
@ -14,6 +14,7 @@ except:
|
||||||
|
|
||||||
BATCH_SIZE = 512
|
BATCH_SIZE = 512
|
||||||
IMG_SIZE = 32
|
IMG_SIZE = 32
|
||||||
|
NUM_EPOCHS = 60
|
||||||
|
|
||||||
train_data = dict(
|
train_data = dict(
|
||||||
dataset=dict(
|
dataset=dict(
|
||||||
|
@ -83,6 +84,14 @@ hooks = [
|
||||||
),
|
),
|
||||||
dict(type='LossHook'),
|
dict(type='LossHook'),
|
||||||
dict(type='TensorboardHook', log_dir='./tfb_logs'),
|
dict(type='TensorboardHook', log_dir='./tfb_logs'),
|
||||||
|
dict(
|
||||||
|
type='LRSchedulerHook',
|
||||||
|
by_epoch=True,
|
||||||
|
lr_scheduler_cfg=dict(
|
||||||
|
type='LinearWarmupLR',
|
||||||
|
warmup_steps=5
|
||||||
|
)
|
||||||
|
),
|
||||||
# dict(type='SaveCheckpointHook', interval=5, checkpoint_dir='./ckpt'),
|
# dict(type='SaveCheckpointHook', interval=5, checkpoint_dir='./ckpt'),
|
||||||
# dict(type='LoadCheckpointHook', epoch=20, checkpoint_dir='./ckpt')
|
# dict(type='LoadCheckpointHook', epoch=20, checkpoint_dir='./ckpt')
|
||||||
]
|
]
|
||||||
|
@ -97,13 +106,6 @@ fp16 = dict(
|
||||||
initial_scale=2 ** 8
|
initial_scale=2 ** 8
|
||||||
)
|
)
|
||||||
|
|
||||||
lr_scheduler = dict(
|
|
||||||
type='LinearWarmupLR',
|
|
||||||
warmup_epochs=5
|
|
||||||
)
|
|
||||||
|
|
||||||
num_epochs = 60
|
|
||||||
|
|
||||||
logging = dict(
|
logging = dict(
|
||||||
root_path='./logs'
|
root_path='./logs'
|
||||||
)
|
)
|
||||||
|
|
|
@ -0,0 +1,5 @@
|
||||||
|
colossalai.engine.amp.amp\_type
|
||||||
|
===============================
|
||||||
|
|
||||||
|
.. automodule:: colossalai.engine.amp.amp_type
|
||||||
|
:members:
|
|
@ -0,0 +1,5 @@
|
||||||
|
colossalai.engine.amp.grad\_scaler
|
||||||
|
==================================
|
||||||
|
|
||||||
|
.. automodule:: colossalai.engine.amp.grad_scaler
|
||||||
|
:members:
|
|
@ -0,0 +1,12 @@
|
||||||
|
colossalai.engine.amp
|
||||||
|
=====================
|
||||||
|
|
||||||
|
.. automodule:: colossalai.engine.amp
|
||||||
|
:members:
|
||||||
|
|
||||||
|
|
||||||
|
.. toctree::
|
||||||
|
:maxdepth: 2
|
||||||
|
|
||||||
|
colossalai.engine.amp.amp_type
|
||||||
|
colossalai.engine.amp.grad_scaler
|
|
@ -1,5 +0,0 @@
|
||||||
colossalai.engine.amp\_type
|
|
||||||
===========================
|
|
||||||
|
|
||||||
.. automodule:: colossalai.engine.amp_type
|
|
||||||
:members:
|
|
|
@ -7,11 +7,6 @@ colossalai.engine
|
||||||
.. toctree::
|
.. toctree::
|
||||||
:maxdepth: 2
|
:maxdepth: 2
|
||||||
|
|
||||||
|
colossalai.engine.amp
|
||||||
colossalai.engine.gradient_handler
|
colossalai.engine.gradient_handler
|
||||||
colossalai.engine.schedule
|
colossalai.engine.schedule
|
||||||
|
|
||||||
|
|
||||||
.. toctree::
|
|
||||||
:maxdepth: 2
|
|
||||||
|
|
||||||
colossalai.engine.amp_type
|
|
||||||
|
|
|
@ -21,7 +21,6 @@ colossalai
|
||||||
.. toctree::
|
.. toctree::
|
||||||
:maxdepth: 2
|
:maxdepth: 2
|
||||||
|
|
||||||
colossalai.checkpointing
|
|
||||||
colossalai.constants
|
colossalai.constants
|
||||||
colossalai.core
|
colossalai.core
|
||||||
colossalai.initialize
|
colossalai.initialize
|
||||||
|
|
|
@ -0,0 +1,5 @@
|
||||||
|
colossalai.utils.checkpointing
|
||||||
|
==============================
|
||||||
|
|
||||||
|
.. automodule:: colossalai.utils.checkpointing
|
||||||
|
:members:
|
|
@ -9,6 +9,7 @@ colossalai.utils
|
||||||
:maxdepth: 2
|
:maxdepth: 2
|
||||||
|
|
||||||
colossalai.utils.activation_checkpoint
|
colossalai.utils.activation_checkpoint
|
||||||
|
colossalai.utils.checkpointing
|
||||||
colossalai.utils.common
|
colossalai.utils.common
|
||||||
colossalai.utils.cuda
|
colossalai.utils.cuda
|
||||||
colossalai.utils.memory
|
colossalai.utils.memory
|
||||||
|
|
|
@ -17,38 +17,40 @@ parallel = dict(
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
The name of the dictionary variable should be **parallel**. All the arguments even **parallel** itself are optional and data,
|
The name of the dictionary variable should be **parallel**. All the arguments even **parallel** itself are optional and
|
||||||
pipeline, tensor parallel size will be set to defaulted value 1. The value of data, pipeline and tensor can be a int
|
data, pipeline, tensor parallel size will be set to defaulted value 1. The value of data, pipeline and tensor can be a
|
||||||
representing the size of specific parallel dimension or a dictionary with a key called "size". The key "mode"
|
int representing the size of specific parallel dimension or a dictionary with a key called "size". The key "mode"
|
||||||
represents the way of tensor parallelism.
|
represents the way of tensor parallelism.
|
||||||
|
|
||||||
## Data Parallel
|
## Data Parallel
|
||||||
|
|
||||||
Data parallel is the most common way to distribute your training task by splitting data into several shards and train
|
Data parallel is the most common way to distribute your training task by splitting data into several shards and train on
|
||||||
on a single shard on each device. The configuration for data parallel is detected automatically and set for you. You do
|
a single shard on each device. The configuration for data parallel is detected automatically and set for you. You do not
|
||||||
not have to explicitly set them in your configurations. When data parallel size is larger than 1, Colossal-AI automatically
|
have to explicitly set them in your configurations. When data parallel size is larger than 1, Colossal-AI automatically
|
||||||
adds the distributed data sampler to the dataloader to shard the dataset.
|
adds the distributed data sampler to the dataloader to shard the dataset.
|
||||||
|
|
||||||
## 1D, 2D, 2.5D and 3D Parallel
|
## 1D, 2D, 2.5D and 3D Parallel
|
||||||
|
|
||||||
To enable hybrid parallelism, we provide an array of tensor parallelism. We provide the list of papers which match each
|
To enable hybrid parallelism, we provide an array of tensor parallelism. We provide the list of papers which match each
|
||||||
tensor parallel method. These parallel modes need to work with the distributed layers provided by Colossal-AI.
|
tensor parallel method. These parallel modes need to work with the distributed layers provided by Colossal-AI.
|
||||||
- 1D: [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053)
|
|
||||||
|
-
|
||||||
|
1D: [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053)
|
||||||
|
|
||||||
- 2D: [An Efficient 2D Method for Training Super-Large Deep Learning Models](https://arxiv.org/abs/2104.05343)
|
- 2D: [An Efficient 2D Method for Training Super-Large Deep Learning Models](https://arxiv.org/abs/2104.05343)
|
||||||
2D parallel relies on the SUMMA matrix multiplication algorithm and splits the input data,
|
2D parallel relies on the SUMMA matrix multiplication algorithm and splits the input data, model weights and layer
|
||||||
model weights and layer outputs along two different dimensions. The tensor chunks are distributed over a 2D mesh of $P = N^2$
|
outputs along two different dimensions. The tensor chunks are distributed over a 2D mesh of $P = N^2$ devices where
|
||||||
devices where $N$ is the number of tensor chunks in a single dimension.
|
$N$ is the number of tensor chunks in a single dimension.
|
||||||
|
|
||||||
- 2.5D: [2.5-dimensional distributed model training](https://arxiv.org/abs/2105.14500)
|
- 2.5D: [2.5-dimensional distributed model training](https://arxiv.org/abs/2105.14500)
|
||||||
Inspired by the 2.5D matrix multiplication algorithm, 2.5D parallel introduces a novel tensor parallelism which further
|
Inspired by the 2.5D matrix multiplication algorithm, 2.5D parallel introduces a novel tensor parallelism which
|
||||||
parallelizes 2D tensor parallelism. An amount of $P = N^2 ∗ d$ processors are arranged into $d$ layers,
|
further parallelizes 2D tensor parallelism. An amount of $P = N^2 ∗ d$ processors are arranged into $d$ layers, where
|
||||||
where each layer performs matrix multiplication operations independently with a dimension $N$.
|
each layer performs matrix multiplication operations independently with a dimension $N$.
|
||||||
|
|
||||||
- 3D: [Maximizing Parallelism in Distributed Training for Huge Neural Networks](https://arxiv.org/abs/2105.14450)
|
- 3D: [Maximizing Parallelism in Distributed Training for Huge Neural Networks](https://arxiv.org/abs/2105.14450)
|
||||||
We also introduce a 3D tensor parallelism that parallelizes neural networks on a 3D processor cube. This method achieves
|
We also introduce a 3D tensor parallelism that parallelizes neural networks on a 3D processor cube. This method
|
||||||
the optimal, $O(P^{1/3})$ communication overhead on $P$ processors, while both computation and memory usage are evenly distributed
|
achieves the optimal, $O(P^{1/3})$ communication overhead on $P$ processors, while both computation and memory usage
|
||||||
through optimized load balancing of parameters as well as activations.
|
are evenly distributed through optimized load balancing of parameters as well as activations.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# 1D parallel
|
# 1D parallel
|
||||||
|
@ -78,12 +80,12 @@ parallel = dict(
|
||||||
|
|
||||||
## Pipeline Parallel (experimental)
|
## Pipeline Parallel (experimental)
|
||||||
|
|
||||||
Pipeline parallelism is to split the model into several partitions by layer. For example, let's assume we have a simple
|
Pipeline parallelism is to split the model into several partitions by layer. For example, let's assume we have a simple
|
||||||
model which consists of two linear layer. We have two GPUs, and we can allocate the first linear layer to the first GPU
|
model which consists of two linear layer. We have two GPUs, and we can allocate the first linear layer to the first GPU
|
||||||
and the second layer to the second GPU. This example of course wastes the computing resources and is only to demonstrate
|
and the second layer to the second GPU. This example of course wastes the computing resources and is only to demonstrate
|
||||||
the idea of pipeline parallelism.
|
the idea of pipeline parallelism.
|
||||||
|
|
||||||
As PyTorch is based on dynamic computation graph, the computation flow is not known until execution. To support pipeline
|
As PyTorch is based on dynamic computation graph, the computation flow is not known until execution. To support pipeline
|
||||||
parallelism in PyTorch, you may need to add one more attribute, `layers_cfg` in your model class which tells Colossal-AI
|
parallelism in PyTorch, you may need to add one more attribute, `layers_cfg` in your model class which tells Colossal-AI
|
||||||
the sequence of execution. One example you can refer is `colossalai.nn.model.VanillaResNet`.
|
the sequence of execution. One example you can refer is `colossalai.nn.model.VanillaResNet`.
|
||||||
|
|
||||||
|
@ -192,9 +194,9 @@ class VanillaResNet(BaseModel):
|
||||||
]
|
]
|
||||||
```
|
```
|
||||||
|
|
||||||
You can set the number of pipeline stages in your configuration file. When pipeline size is larger than 1, Colossal-AI
|
You can set the number of pipeline stages in your configuration file. When pipeline size is larger than 1, Colossal-AI
|
||||||
will automatically creates the pipeline schedule which defines the forward and backward step. You can specify how many microbatches
|
will automatically creates the pipeline schedule which defines the forward and backward step. You can specify how many
|
||||||
to run in each step in the `schedule` configuration.
|
microbatches to run in each step in the `schedule` configuration.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
parallel = dict(
|
parallel = dict(
|
||||||
|
@ -206,10 +208,11 @@ schedule = dict(
|
||||||
num_microbatches = 4 # set the number of microbatches per step
|
num_microbatches = 4 # set the number of microbatches per step
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
This feature is still in development and is only experimental for now.
|
This feature is still in development and is only experimental for now.
|
||||||
|
|
||||||
## Sequence Parallel (experimental)
|
## Sequence Parallel (experimental)
|
||||||
|
|
||||||
Sequence parallel is to support long-sequence modelling such as document-level text understanding and medical imaging.
|
Sequence parallel is to support long-sequence modelling such as document-level text understanding and medical imaging.
|
||||||
This method is proposed in [Sequence Parallelism: Making 4D Parallelism Possible](https://arxiv.org/abs/2105.13120).
|
This method is proposed in [Sequence Parallelism: Making 4D Parallelism Possible](https://arxiv.org/abs/2105.13120).
|
||||||
This feature is still in development and is only experimental for now.
|
This feature is still in development and is only experimental for now.
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
# Quick demo
|
# Quick demo
|
||||||
|
|
||||||
Colossal-AI is an integrated large-scale deep learning system with efficient parallelization techniques. The system
|
Colossal-AI is an integrated large-scale deep learning system with efficient parallelization techniques. The system can
|
||||||
can accelerate model training on distributed systems with multiple GPUs by applying parallelization techniques. The
|
accelerate model training on distributed systems with multiple GPUs by applying parallelization techniques. The system
|
||||||
system can also run on systems with only one GPU. Quick demos showing how to use Colossal-AI are given below.
|
can also run on systems with only one GPU. Quick demos showing how to use Colossal-AI are given below.
|
||||||
|
|
||||||
## Single GPU
|
## Single GPU
|
||||||
|
|
||||||
|
@ -32,25 +32,17 @@ realizes the training process.
|
||||||
```python
|
```python
|
||||||
import colossalai
|
import colossalai
|
||||||
from colossalai.core import global_context as gpc
|
from colossalai.core import global_context as gpc
|
||||||
from colossalai.engine import Engine
|
|
||||||
from colossalai.logging import get_global_dist_logger
|
from colossalai.logging import get_global_dist_logger
|
||||||
from colossalai.trainer import Trainer
|
from colossalai.trainer import Trainer
|
||||||
|
|
||||||
|
|
||||||
def run_trainer():
|
def run_trainer():
|
||||||
model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize()
|
engine, train_dataloader, test_dataloader = colossalai.initialize()
|
||||||
logger = get_global_dist_logger()
|
logger = get_global_dist_logger()
|
||||||
schedule.data_sync = False
|
|
||||||
engine = Engine(
|
|
||||||
model=model,
|
|
||||||
criterion=criterion,
|
|
||||||
optimizer=optimizer,
|
|
||||||
lr_scheduler=lr_scheduler,
|
|
||||||
schedule=schedule
|
|
||||||
)
|
|
||||||
logger.info("engine is built", ranks=[0])
|
logger.info("engine is built", ranks=[0])
|
||||||
|
|
||||||
trainer = Trainer(engine=engine,
|
trainer = Trainer(engine=engine,
|
||||||
hooks_cfg=gpc.config.hooks,
|
|
||||||
verbose=True)
|
verbose=True)
|
||||||
logger.info("trainer is built", ranks=[0])
|
logger.info("trainer is built", ranks=[0])
|
||||||
|
|
||||||
|
@ -58,11 +50,13 @@ def run_trainer():
|
||||||
trainer.fit(
|
trainer.fit(
|
||||||
train_dataloader=train_dataloader,
|
train_dataloader=train_dataloader,
|
||||||
test_dataloader=test_dataloader,
|
test_dataloader=test_dataloader,
|
||||||
max_epochs=gpc.config.num_epochs,
|
epochs=gpc.config.num_epochs,
|
||||||
|
hooks_cfg=gpc.config.hooks,
|
||||||
display_progress=True,
|
display_progress=True,
|
||||||
test_interval=2
|
test_interval=2
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
run_trainer()
|
run_trainer()
|
||||||
```
|
```
|
||||||
|
@ -72,9 +66,9 @@ Zoo. The detailed substitution process is elaborated [here](model.md).
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
Colossal-AI provides a collection of parallel training components for you. We aim to support you with your development of
|
Colossal-AI provides a collection of parallel training components for you. We aim to support you with your development
|
||||||
distributed deep learning models just like how you write single-GPU deep learning models. We provide friendly tools to
|
of distributed deep learning models just like how you write single-GPU deep learning models. We provide friendly tools
|
||||||
kickstart distributed training in a few lines.
|
to kickstart distributed training in a few lines.
|
||||||
|
|
||||||
- [Data Parallelism](parallelization.md)
|
- [Data Parallelism](parallelization.md)
|
||||||
- [Pipeline Parallelism](parallelization.md)
|
- [Pipeline Parallelism](parallelization.md)
|
||||||
|
|
|
@ -4,40 +4,36 @@ Colossal-AI是一个大规模深度学习系统,其中包含高效的并行技
|
||||||
|
|
||||||
## 单GPU系统
|
## 单GPU系统
|
||||||
|
|
||||||
在带有GPU的非分布式系统上进行模型训练时,Colossal-AI可以达到当前的基线效率。[这里](https://colab.research.google.com/drive/1fJnqqFzPuzZ_kn1lwCpG2nh3l2ths0KE?usp=sharing#scrollTo=cQ_y7lBG09LS)我们给出一个Google Colab示例展现如何使用Colossal-AI与CIFAR10数据集在非分布式系统上训练一个LeNet模型。
|
在带有GPU的非分布式系统上进行模型训练时,Colossal-AI可以达到当前的基线效率。[这里](https://colab.research.google.com/drive/1fJnqqFzPuzZ_kn1lwCpG2nh3l2ths0KE?usp=sharing#scrollTo=cQ_y7lBG09LS)我们给出一个Google
|
||||||
|
Colab示例展现如何使用Colossal-AI与CIFAR10数据集在非分布式系统上训练一个LeNet模型。
|
||||||
|
|
||||||
## 多GPU系统
|
## 多GPU系统
|
||||||
|
|
||||||
在多GPU的分布式系统上训练深度学习模型时,Colossal-AI可以使用高效的并行技术来显著地加速训练过程,这些技术将在下面的[并行技术](parallelization.md)章节中被详述。下面的代码将在拥有四个GPU的分布式系统上训练一个ViT模型,其中`HOST`变量为您分布式系统的IP地址。请注意下面的代码使用了[Slurm](https://slurm.schedmd.com/documentation.html)作业调度系统。
|
在多GPU的分布式系统上训练深度学习模型时,Colossal-AI可以使用高效的并行技术来显著地加速训练过程,这些技术将在下面的[并行技术](parallelization.md)
|
||||||
|
章节中被详述。下面的代码将在拥有四个GPU的分布式系统上训练一个ViT模型,其中`HOST`
|
||||||
|
变量为您分布式系统的IP地址。请注意下面的代码使用了[Slurm](https://slurm.schedmd.com/documentation.html)作业调度系统。
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
HOST=xxx.xxx.xxx.xxx srun ./scripts/slurm_dist_train.sh ./examples/run_trainer.py ./configs/vit/vit_2d.py
|
HOST=xxx.xxx.xxx.xxx srun ./scripts/slurm_dist_train.sh ./examples/run_trainer.py ./configs/vit/vit_2d.py
|
||||||
```
|
```
|
||||||
|
|
||||||
`./configs/vit/vit_2d.py`是一个[配置文件](config.md),Colossal-AI使用配置文件来定义训练过程中需要用到的参数,比如模型类型、数据集、以及优化器、学习率调度器等。您可以通过编写配置文件的方式来训练不同的模型。`./examples/run_trainer.py`是一个标准的训练脚本,具体代码已经附在下面。该脚本可以读入配置文件中的训练参数并训练模型。
|
`./configs/vit/vit_2d.py`是一个[配置文件](config.md)
|
||||||
|
,Colossal-AI使用配置文件来定义训练过程中需要用到的参数,比如模型类型、数据集、以及优化器、学习率调度器等。您可以通过编写配置文件的方式来训练不同的模型。`./examples/run_trainer.py`
|
||||||
|
是一个标准的训练脚本,具体代码已经附在下面。该脚本可以读入配置文件中的训练参数并训练模型。
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import colossalai
|
import colossalai
|
||||||
from colossalai.core import global_context as gpc
|
from colossalai.core import global_context as gpc
|
||||||
from colossalai.engine import Engine
|
|
||||||
from colossalai.logging import get_global_dist_logger
|
from colossalai.logging import get_global_dist_logger
|
||||||
from colossalai.trainer import Trainer
|
from colossalai.trainer import Trainer
|
||||||
|
|
||||||
|
|
||||||
def run_trainer():
|
def run_trainer():
|
||||||
model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize()
|
engine, train_dataloader, test_dataloader = colossalai.initialize()
|
||||||
logger = get_global_dist_logger()
|
logger = get_global_dist_logger()
|
||||||
schedule.data_sync = False
|
|
||||||
engine = Engine(
|
|
||||||
model=model,
|
|
||||||
criterion=criterion,
|
|
||||||
optimizer=optimizer,
|
|
||||||
lr_scheduler=lr_scheduler,
|
|
||||||
schedule=schedule
|
|
||||||
)
|
|
||||||
logger.info("engine is built", ranks=[0])
|
logger.info("engine is built", ranks=[0])
|
||||||
|
|
||||||
trainer = Trainer(engine=engine,
|
trainer = Trainer(engine=engine,
|
||||||
hooks_cfg=gpc.config.hooks,
|
|
||||||
verbose=True)
|
verbose=True)
|
||||||
logger.info("trainer is built", ranks=[0])
|
logger.info("trainer is built", ranks=[0])
|
||||||
|
|
||||||
|
@ -45,11 +41,13 @@ def run_trainer():
|
||||||
trainer.fit(
|
trainer.fit(
|
||||||
train_dataloader=train_dataloader,
|
train_dataloader=train_dataloader,
|
||||||
test_dataloader=test_dataloader,
|
test_dataloader=test_dataloader,
|
||||||
max_epochs=gpc.config.num_epochs,
|
epochs=gpc.config.num_epochs,
|
||||||
|
hooks_cfg=gpc.config.hooks,
|
||||||
display_progress=True,
|
display_progress=True,
|
||||||
test_interval=2
|
test_interval=2
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
run_trainer()
|
run_trainer()
|
||||||
```
|
```
|
||||||
|
|
|
@ -2,9 +2,9 @@
|
||||||
|
|
||||||
## Build your engine
|
## Build your engine
|
||||||
|
|
||||||
To better understand how `Engine` class works, let's start from the conception of the process function in common engines. The process function
|
To better understand how `Engine` class works, let's start from the conception of the process function in common
|
||||||
usually controls the behavior over a batch of a dataset, `Engine` class just controls the process function. Here we give a standard process
|
engines. The process function usually controls the behavior over a batch of a dataset, `Engine` class just controls the
|
||||||
function in the following code block.
|
process function. Here we give a standard process function in the following code block.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
def process_function(dataloader, model, criterion, optim):
|
def process_function(dataloader, model, criterion, optim):
|
||||||
|
@ -16,32 +16,33 @@ def process_function(dataloader, model, criterion, optim):
|
||||||
optim.setp()
|
optim.setp()
|
||||||
```
|
```
|
||||||
|
|
||||||
In `ignite.engine` or `keras.engine`, the process function is always provided by users. However, it is tricky for users to write their own process
|
In `ignite.engine` or `keras.engine`, the process function is always provided by users. However, it is tricky for users
|
||||||
functions for pipeline parallelism. Aiming at offering accessible hybrid parallelism for users, we provide the powerful `Engine` class. This class
|
to write their own process functions for pipeline parallelism. Aiming at offering accessible hybrid parallelism for
|
||||||
enables pipeline parallelism and offers one-forward-one-backward non-interleaving strategy. Also, you can use pre-defined learning rate scheduler
|
users, we provide the powerful `Engine` class. This class enables pipeline parallelism and offers
|
||||||
in the `Engine` class to adjust learning rate during training.
|
one-forward-one-backward non-interleaving strategy. Also, you can use pre-defined learning rate scheduler in
|
||||||
|
the `Engine` class to adjust learning rate during training.
|
||||||
|
|
||||||
In order to build your engine, just set variables `model`, `criterion`, `optimizer`, `lr_scheduler` and `schedule`. The following code block provides
|
In order to build your engine, just set variables `model`, `criterion`, `optimizer`, `lr_scheduler` and `schedule`. The
|
||||||
an example.
|
following code block provides an example. **The engine is automatically created from the config file for you if you
|
||||||
|
start with `colossalai.initialize`.**
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torchvision.models as models
|
import torchvision.models as models
|
||||||
import colossalai
|
import colossalai
|
||||||
|
from colossalai.engine import Engine
|
||||||
|
|
||||||
model = models.resnet18()
|
model = models.resnet18()
|
||||||
criterion = nn.CrossEntropyLoss()
|
criterion = nn.CrossEntropyLoss()
|
||||||
optimizer = torch.optim.Adam(model)
|
optimizer = torch.optim.Adam(model.parameters())
|
||||||
lr_scheduler = colossalai.nn.lr_scheduler.CosineAnnealingLR(optimizer, 1000)
|
schedule = colossalai.engine.NoPipelineSchedule()
|
||||||
schedule = colossalai.engine.schedule.NoPipelineSchedule()
|
|
||||||
|
|
||||||
MyEngine = Engine(
|
MyEngine = Engine(
|
||||||
model=model,
|
model=model,
|
||||||
criterion=criterion,
|
criterion=criterion,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
lr_scheduler=lr_scheduler,
|
step_schedule=schedule
|
||||||
schedule=schedule
|
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -51,21 +52,24 @@ More information regarding the class can be found in the API references.
|
||||||
|
|
||||||
### Overview
|
### Overview
|
||||||
|
|
||||||
To learn how to customize a trainer which meets your needs, let's first give a look at the `Trainer` class. We highly recommend that you read *Get Started*
|
To learn how to customize a trainer which meets your needs, let's first give a look at the `Trainer` class. We highly
|
||||||
|
recommend that you read *Get Started*
|
||||||
section and *Build your engine* first.
|
section and *Build your engine* first.
|
||||||
|
|
||||||
The `Trainer` class enables researchers and engineers to use our system more conveniently. Instead of having to write your own scripts, you can simply
|
The `Trainer` class enables researchers and engineers to use our system more conveniently. Instead of having to write
|
||||||
construct your own trainer by calling the `Trainer` class, just like what we did in the following code block.
|
your own scripts, you can simply construct your own trainer by calling the `Trainer` class, just like what we did in the
|
||||||
|
following code block.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
MyTrainer = Trainer(MyEngine)
|
MyTrainer = Trainer(my_engine)
|
||||||
```
|
```
|
||||||
|
|
||||||
After that, you can use the `fit` method to train or evaluate your model. In order to make our `Trainer` class even more powerful, we incorporate a set of
|
After that, you can use the `fit` method to train or evaluate your model. In order to make our `Trainer` class even more
|
||||||
handy tools to the class. For example, you can monitor or record the running states and metrics which indicate the current performance of the model. These
|
powerful, we incorporate a set of handy tools to the class. For example, you can monitor or record the running states
|
||||||
functions are realized by hooks. The `BasicHook` class allows you to execute your hook functions at specified time. We have already created some practical
|
and metrics which indicate the current performance of the model. These functions are realized by hooks. The `BasicHook`
|
||||||
hooks for you, as listed below. What you need to do is just picking the right ones which suit your needs. Detailed descriptions of the class can be found
|
class allows you to execute your hook functions at specified time. We have already created some practical hooks for you,
|
||||||
in the API references.
|
as listed below. What you need to do is just picking the right ones which suit your needs. Detailed descriptions of the
|
||||||
|
class can be found in the API references.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
hooks = [
|
hooks = [
|
||||||
|
@ -80,18 +84,21 @@ hooks = [
|
||||||
]
|
]
|
||||||
```
|
```
|
||||||
|
|
||||||
These hook functions will record metrics, elapsed time and memory usage and write them to log after each epoch. Besides, they print the current loss and
|
These hook functions will record metrics, elapsed time and memory usage and write them to log after each epoch. Besides,
|
||||||
accuracy to let users monitor the performance of the model.
|
they print the current loss and accuracy to let users monitor the performance of the model.
|
||||||
|
|
||||||
### Hook
|
### Hook
|
||||||
|
|
||||||
If you have your specific needs, feel free to extend our `BaseHook` class to add your own functions, or our `MetricHook` class to write a metric collector.
|
If you have your specific needs, feel free to extend our `BaseHook` class to add your own functions, or our `MetricHook`
|
||||||
These hook functions can be called at twelve timing in the trainer's life cycle. Besides, you can define the priorities of all hooks to arrange the execution order of them.
|
class to write a metric collector. These hook functions can be called at twelve timing in the trainer's life cycle.
|
||||||
More information can be found in the API references.
|
Besides, you can define the priorities of all hooks to arrange the execution order of them. More information can be
|
||||||
|
found in the API references.
|
||||||
|
|
||||||
### Metric
|
### Metric
|
||||||
|
|
||||||
You can write your own metrics by extending our `Metric` class. It should be used with the `MetricHook` class. When your write your own metric hooks, please set
|
You can write your own metrics by extending our `Metric` class. It should be used with the `MetricHook` class. When your
|
||||||
the priority carefully and make sure the hook is called before other hooks which might require the results of the metric hook.
|
write your own metric hooks, please set the priority carefully and make sure the hook is called before other hooks which
|
||||||
|
might require the results of the metric hook.
|
||||||
|
|
||||||
We've already provided some metric hooks and we store metric objects in `runner.states['metrics']`. It is a dictionary and metrics can be accessed by their names.
|
We've already provided some metric hooks and we store metric objects in `runner.states['metrics']`. It is a dictionary
|
||||||
|
and metrics can be accessed by their names.
|
||||||
|
|
|
@ -14,28 +14,30 @@ def process_function(dataloader, model, criterion, optim):
|
||||||
optim.setp()
|
optim.setp()
|
||||||
```
|
```
|
||||||
|
|
||||||
在`ignite.engine`与`keras.engine`中,进程函数需要由用户提供,然而,用户很难为流水线并行编写进程函数。为了向用户提供方便的混合并行,我们提供了具备强大功能的`Engine`类,该类支持流水线并行,并提供前向传播后向传播不交织的策略。同时,您可以在`Engine`类中使用您事先定义好的学习率调度器来在训练过程中调整学习率。
|
在`ignite.engine`与`keras.engine`中,进程函数需要由用户提供,然而,用户很难为流水线并行编写进程函数。为了向用户提供方便的混合并行,我们提供了具备强大功能的`Engine`
|
||||||
|
类,该类支持流水线并行,并提供前向传播后向传播不交织的策略。同时,您可以在`Engine`类中使用您事先定义好的学习率调度器来在训练过程中调整学习率。
|
||||||
|
|
||||||
您在构造引擎时只需要定义`model`、`criterion`、`optimizer`、`lr_scheduler`与`schedule`等变量即可,下面的代码块给出了一个这样的例子。
|
您在构造引擎时只需要定义`model`、`criterion`、`optimizer`、`lr_scheduler`与`schedule`等变量即可,下面的代码块给出了一个这样的例子。
|
||||||
|
**如果你使用`colossalai.initialize`的话,engine会从config文件里自动构建。**
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torchvision.models as models
|
import torchvision.models as models
|
||||||
import colossalai
|
import colossalai
|
||||||
|
from colossalai.engine import Engine
|
||||||
|
|
||||||
model = models.resnet18()
|
model = models.resnet18()
|
||||||
criterion = nn.CrossEntropyLoss()
|
criterion = nn.CrossEntropyLoss()
|
||||||
optimizer = torch.optim.Adam(model)
|
optimizer = torch.optim.Adam(model)
|
||||||
lr_scheduler = colossalai.nn.lr_scheduler.CosineAnnealingLR(optimizer, 1000)
|
lr_scheduler = colossalai.nn.lr_scheduler.CosineAnnealingLR(optimizer, 1000)
|
||||||
schedule = colossalai.engine.schedule.NoPipelineSchedule()
|
schedule = colossalai.engine.NoPipelineSchedule()
|
||||||
|
|
||||||
MyEngine = Engine(
|
MyEngine = Engine(
|
||||||
model=model,
|
model=model,
|
||||||
criterion=criterion,
|
criterion=criterion,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
lr_scheduler=lr_scheduler,
|
step_schedule=schedule
|
||||||
schedule=schedule
|
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -48,10 +50,12 @@ MyEngine = Engine(
|
||||||
`Trainer`类旨在让科研工作者和工程师更加方便地使用我们的系统,您不需要自己写脚本,只需要调用`Trainer`类来构造您的训练器即可,就像下面的代码块中所做的。
|
`Trainer`类旨在让科研工作者和工程师更加方便地使用我们的系统,您不需要自己写脚本,只需要调用`Trainer`类来构造您的训练器即可,就像下面的代码块中所做的。
|
||||||
|
|
||||||
```python
|
```python
|
||||||
MyTrainer = Trainer(MyEngine)
|
MyTrainer = Trainer(my_trainer)
|
||||||
```
|
```
|
||||||
|
|
||||||
在此之后,您可以使用`fit`方法来训练或调用您的模型。除此之外,为了让我们的`Trainer`类拥有更强大的功能,我们加入了一系列方便您使用的工具。例如,您可以在训练过程中持续监测并记录模型目前的运行状态和表现,这些功能都是通过钩子函数来实现的。我们提供的`BasicHook`类让您可以在指定时间执行您的钩子函数。如下方的代码块所示,我们事先为您定义好了一些实用的钩子函数,您需要做的就是找到符合您需求的钩子函数。更多该类的相关信息可以在API信息中找到。
|
在此之后,您可以使用`fit`方法来训练或调用您的模型。除此之外,为了让我们的`Trainer`
|
||||||
|
类拥有更强大的功能,我们加入了一系列方便您使用的工具。例如,您可以在训练过程中持续监测并记录模型目前的运行状态和表现,这些功能都是通过钩子函数来实现的。我们提供的`BasicHook`
|
||||||
|
类让您可以在指定时间执行您的钩子函数。如下方的代码块所示,我们事先为您定义好了一些实用的钩子函数,您需要做的就是找到符合您需求的钩子函数。更多该类的相关信息可以在API信息中找到。
|
||||||
|
|
||||||
```python
|
```python
|
||||||
hooks = [
|
hooks = [
|
||||||
|
@ -70,7 +74,8 @@ hooks = [
|
||||||
|
|
||||||
### 钩子函数
|
### 钩子函数
|
||||||
|
|
||||||
如果您有个性化需求,您可以继承我们的`BaseHook`类并添加您的钩子函数,或者继承我们的`MetricHook`来编写您需要的度量标准。这些钩子函数可以在`Trainer`生命周期的12个时间点被执行。更多该类的相关信息可以在API信息中找到。
|
如果您有个性化需求,您可以继承我们的`BaseHook`类并添加您的钩子函数,或者继承我们的`MetricHook`来编写您需要的度量标准。这些钩子函数可以在`Trainer`
|
||||||
|
生命周期的12个时间点被执行。更多该类的相关信息可以在API信息中找到。
|
||||||
|
|
||||||
### 度量标准
|
### 度量标准
|
||||||
|
|
||||||
|
|
|
@ -1,370 +1,370 @@
|
||||||
{
|
{
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
"nbformat_minor": 0,
|
"nbformat_minor": 0,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"colab": {
|
"colab": {
|
||||||
"name": "colossal_cifar_demo.ipynb",
|
"name": "colossal_cifar_demo.ipynb",
|
||||||
"provenance": []
|
"provenance": []
|
||||||
},
|
|
||||||
"kernelspec": {
|
|
||||||
"name": "python3",
|
|
||||||
"display_name": "Python 3"
|
|
||||||
},
|
|
||||||
"language_info": {
|
|
||||||
"name": "python"
|
|
||||||
},
|
|
||||||
"accelerator": "GPU"
|
|
||||||
},
|
},
|
||||||
"cells": [
|
"kernelspec": {
|
||||||
{
|
"name": "python3",
|
||||||
"cell_type": "markdown",
|
"display_name": "Python 3"
|
||||||
"metadata": {
|
},
|
||||||
"id": "uhrbvVEh2iJd"
|
"language_info": {
|
||||||
},
|
"name": "python"
|
||||||
"source": [
|
},
|
||||||
"# Train an image classifier\n"
|
"accelerator": "GPU"
|
||||||
]
|
},
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "uhrbvVEh2iJd"
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"# Train an image classifier\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"metadata": {
|
||||||
|
"colab": {
|
||||||
|
"base_uri": "https://localhost:8080/"
|
||||||
},
|
},
|
||||||
|
"id": "vP7LvCpG23a2",
|
||||||
|
"outputId": "b37f7203-8a02-4736-c527-603f2bb34d7d"
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"!pip install ColossalAI deepspeed"
|
||||||
|
],
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"output_type": "stream",
|
||||||
"metadata": {
|
"name": "stdout",
|
||||||
"colab": {
|
"text": [
|
||||||
"base_uri": "https://localhost:8080/"
|
"Requirement already satisfied: ColossalAI in /usr/local/lib/python3.7/dist-packages (0.1)\n",
|
||||||
},
|
"Requirement already satisfied: deepspeed in /usr/local/lib/python3.7/dist-packages (0.5.4)\n",
|
||||||
"id": "vP7LvCpG23a2",
|
"Requirement already satisfied: packaging in /usr/local/lib/python3.7/dist-packages (from deepspeed) (21.0)\n",
|
||||||
"outputId": "b37f7203-8a02-4736-c527-603f2bb34d7d"
|
"Requirement already satisfied: triton in /usr/local/lib/python3.7/dist-packages (from deepspeed) (1.1.1)\n",
|
||||||
},
|
"Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (from deepspeed) (4.62.3)\n",
|
||||||
"source": [
|
"Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from deepspeed) (1.19.5)\n",
|
||||||
"!pip install ColossalAI deepspeed"
|
"Requirement already satisfied: tensorboardX==1.8 in /usr/local/lib/python3.7/dist-packages (from deepspeed) (1.8)\n",
|
||||||
],
|
"Requirement already satisfied: ninja in /usr/local/lib/python3.7/dist-packages (from deepspeed) (1.10.2.2)\n",
|
||||||
"execution_count": null,
|
"Requirement already satisfied: torch in /usr/local/lib/python3.7/dist-packages (from deepspeed) (1.9.0+cu111)\n",
|
||||||
"outputs": [
|
"Requirement already satisfied: psutil in /usr/local/lib/python3.7/dist-packages (from deepspeed) (5.4.8)\n",
|
||||||
{
|
"Requirement already satisfied: protobuf>=3.2.0 in /usr/local/lib/python3.7/dist-packages (from tensorboardX==1.8->deepspeed) (3.17.3)\n",
|
||||||
"output_type": "stream",
|
"Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from tensorboardX==1.8->deepspeed) (1.15.0)\n",
|
||||||
"name": "stdout",
|
"Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging->deepspeed) (2.4.7)\n",
|
||||||
"text": [
|
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch->deepspeed) (3.7.4.3)\n",
|
||||||
"Requirement already satisfied: ColossalAI in /usr/local/lib/python3.7/dist-packages (0.1)\n",
|
"Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from triton->deepspeed) (3.3.0)\n"
|
||||||
"Requirement already satisfied: deepspeed in /usr/local/lib/python3.7/dist-packages (0.5.4)\n",
|
]
|
||||||
"Requirement already satisfied: packaging in /usr/local/lib/python3.7/dist-packages (from deepspeed) (21.0)\n",
|
|
||||||
"Requirement already satisfied: triton in /usr/local/lib/python3.7/dist-packages (from deepspeed) (1.1.1)\n",
|
|
||||||
"Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (from deepspeed) (4.62.3)\n",
|
|
||||||
"Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from deepspeed) (1.19.5)\n",
|
|
||||||
"Requirement already satisfied: tensorboardX==1.8 in /usr/local/lib/python3.7/dist-packages (from deepspeed) (1.8)\n",
|
|
||||||
"Requirement already satisfied: ninja in /usr/local/lib/python3.7/dist-packages (from deepspeed) (1.10.2.2)\n",
|
|
||||||
"Requirement already satisfied: torch in /usr/local/lib/python3.7/dist-packages (from deepspeed) (1.9.0+cu111)\n",
|
|
||||||
"Requirement already satisfied: psutil in /usr/local/lib/python3.7/dist-packages (from deepspeed) (5.4.8)\n",
|
|
||||||
"Requirement already satisfied: protobuf>=3.2.0 in /usr/local/lib/python3.7/dist-packages (from tensorboardX==1.8->deepspeed) (3.17.3)\n",
|
|
||||||
"Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from tensorboardX==1.8->deepspeed) (1.15.0)\n",
|
|
||||||
"Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging->deepspeed) (2.4.7)\n",
|
|
||||||
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch->deepspeed) (3.7.4.3)\n",
|
|
||||||
"Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from triton->deepspeed) (3.3.0)\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"metadata": {
|
|
||||||
"colab": {
|
|
||||||
"base_uri": "https://localhost:8080/"
|
|
||||||
},
|
|
||||||
"id": "UVKEurtS4SFS",
|
|
||||||
"outputId": "99fb6050-5da7-4f27-b4eb-9b3ccf830efb"
|
|
||||||
},
|
|
||||||
"source": [
|
|
||||||
"import colossalai\n",
|
|
||||||
"from colossalai.engine import Engine, NoPipelineSchedule\n",
|
|
||||||
"from colossalai.trainer import Trainer\n",
|
|
||||||
"from colossalai.context import Config\n",
|
|
||||||
"import torch"
|
|
||||||
],
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"output_type": "stream",
|
|
||||||
"name": "stdout",
|
|
||||||
"text": [
|
|
||||||
"Please install apex to use FP16 Optimizer\n",
|
|
||||||
"Apex should be installed to use the FP16 optimizer\n",
|
|
||||||
"apex is required for mixed precision training\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {
|
|
||||||
"id": "PpFfhNBD7NSn"
|
|
||||||
},
|
|
||||||
"source": [
|
|
||||||
"First, we should initialize distributed environment. Though we just use single GPU in this example, we still need initialize distributed environment for compatibility. We just consider the simplest case here, so we just set the number of parallel processes to 1."
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"metadata": {
|
|
||||||
"colab": {
|
|
||||||
"base_uri": "https://localhost:8080/"
|
|
||||||
},
|
|
||||||
"id": "8yF7Lc-K7NAS",
|
|
||||||
"outputId": "01312349-a8b0-4de4-9103-7d1b48e6cc36"
|
|
||||||
},
|
|
||||||
"source": [
|
|
||||||
"parallel_cfg = Config(dict(parallel=dict(\n",
|
|
||||||
" data=dict(size=1),\n",
|
|
||||||
" pipeline=dict(size=1),\n",
|
|
||||||
" tensor=dict(size=1, mode=None),\n",
|
|
||||||
")))\n",
|
|
||||||
"colossalai.init_dist(config=parallel_cfg,\n",
|
|
||||||
" local_rank=0,\n",
|
|
||||||
" world_size=1,\n",
|
|
||||||
" host='127.0.0.1',\n",
|
|
||||||
" port=8888,\n",
|
|
||||||
" backend='nccl')"
|
|
||||||
],
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"output_type": "stream",
|
|
||||||
"name": "stderr",
|
|
||||||
"text": [
|
|
||||||
"colossalai - torch.distributed.distributed_c10d - 2021-10-15 03:27:51,596 INFO: Added key: store_based_barrier_key:1 to store for rank: 0\n",
|
|
||||||
"colossalai - torch.distributed.distributed_c10d - 2021-10-15 03:27:51,598 INFO: Rank 0: Completed store-based barrier for 1 nodes.\n",
|
|
||||||
"colossalai - torch.distributed.distributed_c10d - 2021-10-15 03:27:51,602 INFO: Added key: store_based_barrier_key:2 to store for rank: 0\n",
|
|
||||||
"colossalai - torch.distributed.distributed_c10d - 2021-10-15 03:27:51,605 INFO: Rank 0: Completed store-based barrier for 1 nodes.\n",
|
|
||||||
"colossalai - torch.distributed.distributed_c10d - 2021-10-15 03:27:51,608 INFO: Added key: store_based_barrier_key:3 to store for rank: 0\n",
|
|
||||||
"colossalai - torch.distributed.distributed_c10d - 2021-10-15 03:27:51,610 INFO: Rank 0: Completed store-based barrier for 1 nodes.\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"output_type": "stream",
|
|
||||||
"name": "stdout",
|
|
||||||
"text": [
|
|
||||||
"process rank 0 is bound to device 0\n",
|
|
||||||
"initialized seed on rank 0, numpy: 1024, python random: 1024, ParallelMode.DATA: 1024, ParallelMode.TENSOR: 1124,the default parallel seed is ParallelMode.DATA.\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {
|
|
||||||
"id": "ppjmMxc_81TK"
|
|
||||||
},
|
|
||||||
"source": [
|
|
||||||
"Load and normalize the CIFAR10 training and test datasets using `colossalai.nn.data`. Note that we have wrapped `torchvision.transforms`, so that we can simply use the config dict to use them."
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"metadata": {
|
|
||||||
"id": "ZyGhyD47-dUY",
|
|
||||||
"colab": {
|
|
||||||
"base_uri": "https://localhost:8080/"
|
|
||||||
},
|
|
||||||
"outputId": "98bbf2d1-a1c4-4bb4-b6df-600777b1e8f5"
|
|
||||||
},
|
|
||||||
"source": [
|
|
||||||
"transform_cfg = [\n",
|
|
||||||
" dict(type='ToTensor'),\n",
|
|
||||||
" dict(type='Normalize',\n",
|
|
||||||
" mean=[0.4914, 0.4822, 0.4465],\n",
|
|
||||||
" std=[0.2023, 0.1994, 0.2010]),\n",
|
|
||||||
"]\n",
|
|
||||||
"\n",
|
|
||||||
"batch_size = 128\n",
|
|
||||||
"\n",
|
|
||||||
"trainset = colossalai.nn.data.CIFAR10Dataset(transform_cfg, root='./data', train=True)\n",
|
|
||||||
"trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)\n",
|
|
||||||
"\n",
|
|
||||||
"testset = colossalai.nn.data.CIFAR10Dataset(transform_cfg, root='./data', train=False)\n",
|
|
||||||
"testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)"
|
|
||||||
],
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"output_type": "stream",
|
|
||||||
"name": "stdout",
|
|
||||||
"text": [
|
|
||||||
"Files already downloaded and verified\n",
|
|
||||||
"Files already downloaded and verified\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {
|
|
||||||
"id": "NvPbfLLR9NzC"
|
|
||||||
},
|
|
||||||
"source": [
|
|
||||||
"We just define a simple Convolutional Neural Network here."
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"metadata": {
|
|
||||||
"id": "cQ_y7lBG09LS"
|
|
||||||
},
|
|
||||||
"source": [
|
|
||||||
"import torch.nn as nn\n",
|
|
||||||
"import torch.nn.functional as F\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
|
||||||
"class Net(nn.Module):\n",
|
|
||||||
" def __init__(self):\n",
|
|
||||||
" super().__init__()\n",
|
|
||||||
" self.conv1 = nn.Conv2d(3, 6, 5)\n",
|
|
||||||
" self.pool = nn.MaxPool2d(2, 2)\n",
|
|
||||||
" self.conv2 = nn.Conv2d(6, 16, 5)\n",
|
|
||||||
" self.fc1 = nn.Linear(16 * 5 * 5, 120)\n",
|
|
||||||
" self.fc2 = nn.Linear(120, 84)\n",
|
|
||||||
" self.fc3 = nn.Linear(84, 10)\n",
|
|
||||||
"\n",
|
|
||||||
" def forward(self, x):\n",
|
|
||||||
" x = self.pool(F.relu(self.conv1(x)))\n",
|
|
||||||
" x = self.pool(F.relu(self.conv2(x)))\n",
|
|
||||||
" x = torch.flatten(x, 1) # flatten all dimensions except batch\n",
|
|
||||||
" x = F.relu(self.fc1(x))\n",
|
|
||||||
" x = F.relu(self.fc2(x))\n",
|
|
||||||
" x = self.fc3(x)\n",
|
|
||||||
" return x\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
|
||||||
"model = Net().cuda()"
|
|
||||||
],
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": []
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {
|
|
||||||
"id": "tgsszAmM9dYZ"
|
|
||||||
},
|
|
||||||
"source": [
|
|
||||||
"Define a Loss function and optimizer. And then we use them to initialize `Engine` and `Trainer`. We provide various training / evaluating hooks. In this case, we just use the simplest hooks which can compute and print loss and accuracy."
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"metadata": {
|
|
||||||
"colab": {
|
|
||||||
"base_uri": "https://localhost:8080/"
|
|
||||||
},
|
|
||||||
"id": "YtaDoCax1BCf",
|
|
||||||
"outputId": "b33b1641-03d8-4597-c8c2-1a4c1d61e9b0"
|
|
||||||
},
|
|
||||||
"source": [
|
|
||||||
"import torch.optim as optim\n",
|
|
||||||
"\n",
|
|
||||||
"criterion = nn.CrossEntropyLoss()\n",
|
|
||||||
"optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)\n",
|
|
||||||
"schedule = NoPipelineSchedule()\n",
|
|
||||||
"engine = Engine(\n",
|
|
||||||
" model=model,\n",
|
|
||||||
" criterion=criterion,\n",
|
|
||||||
" optimizer=optimizer,\n",
|
|
||||||
" lr_scheduler=None,\n",
|
|
||||||
" schedule=schedule\n",
|
|
||||||
" )\n",
|
|
||||||
"trainer = Trainer(engine=engine,\n",
|
|
||||||
" hooks_cfg=[dict(type='LossHook'), dict(type='LogMetricByEpochHook'), dict(type='AccuracyHook')],\n",
|
|
||||||
" verbose=True)"
|
|
||||||
],
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"output_type": "stream",
|
|
||||||
"name": "stderr",
|
|
||||||
"text": [
|
|
||||||
"colossalai - rank_0 - 2021-10-15 03:27:56,018 WARNING: No gradient handler is set up, please make sure you do not need to all-reduce the gradients after a training step.\n",
|
|
||||||
"colossalai - rank_0 - 2021-10-15 03:27:56,024 INFO: build LogMetricByEpochHook for train, priority = 1\n",
|
|
||||||
"colossalai - rank_0 - 2021-10-15 03:27:56,026 INFO: build LossHook for train, priority = 10\n",
|
|
||||||
"colossalai - rank_0 - 2021-10-15 03:27:56,029 INFO: build AccuracyHook for train, priority = 10\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {
|
|
||||||
"id": "_JR2TuvH99Ik"
|
|
||||||
},
|
|
||||||
"source": [
|
|
||||||
"Then we set training configs. We train our model for 10 epochs and it will be evaluated every 1 epoch. Set `display_progress` to `True` to display the training / evaluating progress bar."
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"metadata": {
|
|
||||||
"colab": {
|
|
||||||
"base_uri": "https://localhost:8080/"
|
|
||||||
},
|
|
||||||
"id": "w-J3IP-J1sfx",
|
|
||||||
"outputId": "bdb76939-04f1-4124-ce5e-3af44c0d902c"
|
|
||||||
},
|
|
||||||
"source": [
|
|
||||||
"num_epochs = 10\n",
|
|
||||||
"test_interval = 1\n",
|
|
||||||
"trainer.fit(\n",
|
|
||||||
" train_dataloader=trainloader,\n",
|
|
||||||
" test_dataloader=testloader,\n",
|
|
||||||
" max_epochs=num_epochs,\n",
|
|
||||||
" display_progress=True,\n",
|
|
||||||
" test_interval=test_interval\n",
|
|
||||||
" )"
|
|
||||||
],
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"output_type": "stream",
|
|
||||||
"name": "stderr",
|
|
||||||
"text": [
|
|
||||||
"[Epoch 0 train]: 0%| | 0/391 [00:00<?, ?it/s]/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /pytorch/c10/core/TensorImpl.h:1156.)\n",
|
|
||||||
" return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)\n",
|
|
||||||
"[Epoch 0 train]: 100%|██████████| 391/391 [00:14<00:00, 26.82it/s]\n",
|
|
||||||
"colossalai - rank_0 - 2021-10-15 03:28:11,088 INFO: Training - Epoch 1 - LogMetricByEpochHook: Loss = 2.29158\n",
|
|
||||||
"[Epoch 0 val]: 100%|██████████| 79/79 [00:02<00:00, 28.66it/s]\n",
|
|
||||||
"colossalai - rank_0 - 2021-10-15 03:28:14,040 INFO: Testing - Epoch 1 - LogMetricByEpochHook: Loss = 2.26517, Accuracy = 0.14820\n",
|
|
||||||
"[Epoch 1 train]: 100%|██████████| 391/391 [00:14<00:00, 26.31it/s]\n",
|
|
||||||
"colossalai - rank_0 - 2021-10-15 03:28:29,059 INFO: Training - Epoch 2 - LogMetricByEpochHook: Loss = 2.15763\n",
|
|
||||||
"[Epoch 1 val]: 100%|██████████| 79/79 [00:02<00:00, 28.50it/s]\n",
|
|
||||||
"colossalai - rank_0 - 2021-10-15 03:28:32,007 INFO: Testing - Epoch 2 - LogMetricByEpochHook: Loss = 2.00450, Accuracy = 0.27850\n",
|
|
||||||
"[Epoch 2 train]: 100%|██████████| 391/391 [00:14<00:00, 26.08it/s]\n",
|
|
||||||
"colossalai - rank_0 - 2021-10-15 03:28:47,167 INFO: Training - Epoch 3 - LogMetricByEpochHook: Loss = 1.85409\n",
|
|
||||||
"[Epoch 2 val]: 100%|██████████| 79/79 [00:02<00:00, 27.89it/s]\n",
|
|
||||||
"colossalai - rank_0 - 2021-10-15 03:28:50,168 INFO: Testing - Epoch 3 - LogMetricByEpochHook: Loss = 1.73788, Accuracy = 0.35990\n",
|
|
||||||
"[Epoch 3 train]: 100%|██████████| 391/391 [00:14<00:00, 26.09it/s]\n",
|
|
||||||
"colossalai - rank_0 - 2021-10-15 03:29:05,330 INFO: Training - Epoch 4 - LogMetricByEpochHook: Loss = 1.69363\n",
|
|
||||||
"[Epoch 3 val]: 100%|██████████| 79/79 [00:02<00:00, 28.43it/s]\n",
|
|
||||||
"colossalai - rank_0 - 2021-10-15 03:29:08,290 INFO: Testing - Epoch 4 - LogMetricByEpochHook: Loss = 1.65005, Accuracy = 0.39350\n",
|
|
||||||
"[Epoch 4 train]: 100%|██████████| 391/391 [00:15<00:00, 25.97it/s]\n",
|
|
||||||
"colossalai - rank_0 - 2021-10-15 03:29:23,530 INFO: Training - Epoch 5 - LogMetricByEpochHook: Loss = 1.61387\n",
|
|
||||||
"[Epoch 4 val]: 100%|██████████| 79/79 [00:02<00:00, 27.75it/s]\n",
|
|
||||||
"colossalai - rank_0 - 2021-10-15 03:29:26,515 INFO: Testing - Epoch 5 - LogMetricByEpochHook: Loss = 1.57507, Accuracy = 0.42430\n",
|
|
||||||
"[Epoch 5 train]: 100%|██████████| 391/391 [00:15<00:00, 25.92it/s]\n",
|
|
||||||
"colossalai - rank_0 - 2021-10-15 03:29:41,764 INFO: Training - Epoch 6 - LogMetricByEpochHook: Loss = 1.55712\n",
|
|
||||||
"[Epoch 5 val]: 100%|██████████| 79/79 [00:02<00:00, 27.51it/s]\n",
|
|
||||||
"colossalai - rank_0 - 2021-10-15 03:29:44,778 INFO: Testing - Epoch 6 - LogMetricByEpochHook: Loss = 1.53242, Accuracy = 0.43700\n",
|
|
||||||
"[Epoch 6 train]: 100%|██████████| 391/391 [00:14<00:00, 26.13it/s]\n",
|
|
||||||
"colossalai - rank_0 - 2021-10-15 03:29:59,927 INFO: Training - Epoch 7 - LogMetricByEpochHook: Loss = 1.51618\n",
|
|
||||||
"[Epoch 6 val]: 100%|██████████| 79/79 [00:02<00:00, 28.31it/s]\n",
|
|
||||||
"colossalai - rank_0 - 2021-10-15 03:30:02,884 INFO: Testing - Epoch 7 - LogMetricByEpochHook: Loss = 1.49720, Accuracy = 0.45430\n",
|
|
||||||
"[Epoch 7 train]: 100%|██████████| 391/391 [00:14<00:00, 26.23it/s]\n",
|
|
||||||
"colossalai - rank_0 - 2021-10-15 03:30:17,968 INFO: Training - Epoch 8 - LogMetricByEpochHook: Loss = 1.47857\n",
|
|
||||||
"[Epoch 7 val]: 100%|██████████| 79/79 [00:02<00:00, 27.97it/s]\n",
|
|
||||||
"colossalai - rank_0 - 2021-10-15 03:30:20,967 INFO: Testing - Epoch 8 - LogMetricByEpochHook: Loss = 1.45808, Accuracy = 0.46320\n",
|
|
||||||
"[Epoch 8 train]: 100%|██████████| 391/391 [00:14<00:00, 26.11it/s]\n",
|
|
||||||
"colossalai - rank_0 - 2021-10-15 03:30:36,129 INFO: Training - Epoch 9 - LogMetricByEpochHook: Loss = 1.44656\n",
|
|
||||||
"[Epoch 8 val]: 100%|██████████| 79/79 [00:02<00:00, 28.18it/s]\n",
|
|
||||||
"colossalai - rank_0 - 2021-10-15 03:30:39,096 INFO: Testing - Epoch 9 - LogMetricByEpochHook: Loss = 1.44903, Accuracy = 0.46580\n",
|
|
||||||
"[Epoch 9 train]: 100%|██████████| 391/391 [00:15<00:00, 25.97it/s]\n",
|
|
||||||
"colossalai - rank_0 - 2021-10-15 03:30:54,342 INFO: Training - Epoch 10 - LogMetricByEpochHook: Loss = 1.41120\n",
|
|
||||||
"[Epoch 9 val]: 100%|██████████| 79/79 [00:02<00:00, 28.05it/s]\n",
|
|
||||||
"colossalai - rank_0 - 2021-10-15 03:30:57,332 INFO: Testing - Epoch 10 - LogMetricByEpochHook: Loss = 1.41242, Accuracy = 0.48500\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"metadata": {
|
||||||
|
"colab": {
|
||||||
|
"base_uri": "https://localhost:8080/"
|
||||||
|
},
|
||||||
|
"id": "UVKEurtS4SFS",
|
||||||
|
"outputId": "99fb6050-5da7-4f27-b4eb-9b3ccf830efb"
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"import colossalai\n",
|
||||||
|
"from colossalai.engine import Engine, NoPipelineSchedule\n",
|
||||||
|
"from colossalai.trainer import Trainer\n",
|
||||||
|
"from colossalai.context import Config\n",
|
||||||
|
"import torch"
|
||||||
|
],
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"output_type": "stream",
|
||||||
|
"name": "stdout",
|
||||||
|
"text": [
|
||||||
|
"Please install apex to use FP16 Optimizer\n",
|
||||||
|
"Apex should be installed to use the FP16 optimizer\n",
|
||||||
|
"apex is required for mixed precision training\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "PpFfhNBD7NSn"
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"First, we should initialize distributed environment. Though we just use single GPU in this example, we still need initialize distributed environment for compatibility. We just consider the simplest case here, so we just set the number of parallel processes to 1."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"metadata": {
|
||||||
|
"colab": {
|
||||||
|
"base_uri": "https://localhost:8080/"
|
||||||
|
},
|
||||||
|
"id": "8yF7Lc-K7NAS",
|
||||||
|
"outputId": "01312349-a8b0-4de4-9103-7d1b48e6cc36"
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"parallel_cfg = Config(dict(parallel=dict(\n",
|
||||||
|
" data=dict(size=1),\n",
|
||||||
|
" pipeline=dict(size=1),\n",
|
||||||
|
" tensor=dict(size=1, mode=None),\n",
|
||||||
|
")))\n",
|
||||||
|
"colossalai.init_dist(config=parallel_cfg,\n",
|
||||||
|
" local_rank=0,\n",
|
||||||
|
" world_size=1,\n",
|
||||||
|
" host='127.0.0.1',\n",
|
||||||
|
" port=8888,\n",
|
||||||
|
" backend='nccl')"
|
||||||
|
],
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"output_type": "stream",
|
||||||
|
"name": "stderr",
|
||||||
|
"text": [
|
||||||
|
"colossalai - torch.distributed.distributed_c10d - 2021-10-15 03:27:51,596 INFO: Added key: store_based_barrier_key:1 to store for rank: 0\n",
|
||||||
|
"colossalai - torch.distributed.distributed_c10d - 2021-10-15 03:27:51,598 INFO: Rank 0: Completed store-based barrier for 1 nodes.\n",
|
||||||
|
"colossalai - torch.distributed.distributed_c10d - 2021-10-15 03:27:51,602 INFO: Added key: store_based_barrier_key:2 to store for rank: 0\n",
|
||||||
|
"colossalai - torch.distributed.distributed_c10d - 2021-10-15 03:27:51,605 INFO: Rank 0: Completed store-based barrier for 1 nodes.\n",
|
||||||
|
"colossalai - torch.distributed.distributed_c10d - 2021-10-15 03:27:51,608 INFO: Added key: store_based_barrier_key:3 to store for rank: 0\n",
|
||||||
|
"colossalai - torch.distributed.distributed_c10d - 2021-10-15 03:27:51,610 INFO: Rank 0: Completed store-based barrier for 1 nodes.\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"output_type": "stream",
|
||||||
|
"name": "stdout",
|
||||||
|
"text": [
|
||||||
|
"process rank 0 is bound to device 0\n",
|
||||||
|
"initialized seed on rank 0, numpy: 1024, python random: 1024, ParallelMode.DATA: 1024, ParallelMode.TENSOR: 1124,the default parallel seed is ParallelMode.DATA.\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "ppjmMxc_81TK"
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"Load and normalize the CIFAR10 training and test datasets using `colossalai.nn.data`. Note that we have wrapped `torchvision.transforms`, so that we can simply use the config dict to use them."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"metadata": {
|
||||||
|
"id": "ZyGhyD47-dUY",
|
||||||
|
"colab": {
|
||||||
|
"base_uri": "https://localhost:8080/"
|
||||||
|
},
|
||||||
|
"outputId": "98bbf2d1-a1c4-4bb4-b6df-600777b1e8f5"
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"transform_cfg = [\n",
|
||||||
|
" dict(type='ToTensor'),\n",
|
||||||
|
" dict(type='Normalize',\n",
|
||||||
|
" mean=[0.4914, 0.4822, 0.4465],\n",
|
||||||
|
" std=[0.2023, 0.1994, 0.2010]),\n",
|
||||||
|
"]\n",
|
||||||
|
"\n",
|
||||||
|
"batch_size = 128\n",
|
||||||
|
"\n",
|
||||||
|
"trainset = colossalai.nn.data.CIFAR10Dataset(transform_cfg, root='./data', train=True)\n",
|
||||||
|
"trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)\n",
|
||||||
|
"\n",
|
||||||
|
"testset = colossalai.nn.data.CIFAR10Dataset(transform_cfg, root='./data', train=False)\n",
|
||||||
|
"testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)"
|
||||||
|
],
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"output_type": "stream",
|
||||||
|
"name": "stdout",
|
||||||
|
"text": [
|
||||||
|
"Files already downloaded and verified\n",
|
||||||
|
"Files already downloaded and verified\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "NvPbfLLR9NzC"
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"We just define a simple Convolutional Neural Network here."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"metadata": {
|
||||||
|
"id": "cQ_y7lBG09LS"
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"import torch.nn as nn\n",
|
||||||
|
"import torch.nn.functional as F\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"class Net(nn.Module):\n",
|
||||||
|
" def __init__(self):\n",
|
||||||
|
" super().__init__()\n",
|
||||||
|
" self.conv1 = nn.Conv2d(3, 6, 5)\n",
|
||||||
|
" self.pool = nn.MaxPool2d(2, 2)\n",
|
||||||
|
" self.conv2 = nn.Conv2d(6, 16, 5)\n",
|
||||||
|
" self.fc1 = nn.Linear(16 * 5 * 5, 120)\n",
|
||||||
|
" self.fc2 = nn.Linear(120, 84)\n",
|
||||||
|
" self.fc3 = nn.Linear(84, 10)\n",
|
||||||
|
"\n",
|
||||||
|
" def forward(self, x):\n",
|
||||||
|
" x = self.pool(F.relu(self.conv1(x)))\n",
|
||||||
|
" x = self.pool(F.relu(self.conv2(x)))\n",
|
||||||
|
" x = torch.flatten(x, 1) # flatten all dimensions except batch\n",
|
||||||
|
" x = F.relu(self.fc1(x))\n",
|
||||||
|
" x = F.relu(self.fc2(x))\n",
|
||||||
|
" x = self.fc3(x)\n",
|
||||||
|
" return x\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"model = Net().cuda()"
|
||||||
|
],
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "tgsszAmM9dYZ"
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"Define a Loss function and optimizer. And then we use them to initialize `Engine` and `Trainer`. We provide various training / evaluating hooks. In this case, we just use the simplest hooks which can compute and print loss and accuracy."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"metadata": {
|
||||||
|
"colab": {
|
||||||
|
"base_uri": "https://localhost:8080/"
|
||||||
|
},
|
||||||
|
"id": "YtaDoCax1BCf",
|
||||||
|
"outputId": "b33b1641-03d8-4597-c8c2-1a4c1d61e9b0"
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"import torch.optim as optim\n",
|
||||||
|
"\n",
|
||||||
|
"criterion = nn.CrossEntropyLoss()\n",
|
||||||
|
"optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)\n",
|
||||||
|
"schedule = NoPipelineSchedule()\n",
|
||||||
|
"engine = Engine(\n",
|
||||||
|
" model=model,\n",
|
||||||
|
" criterion=criterion,\n",
|
||||||
|
" optimizer=optimizer,\n",
|
||||||
|
" lr_scheduler=None,\n",
|
||||||
|
" schedule=schedule\n",
|
||||||
|
" )\n",
|
||||||
|
"trainer = Trainer(engine=engine,\n",
|
||||||
|
" hooks_cfg=[dict(type='LossHook'), dict(type='LogMetricByEpochHook'), dict(type='AccuracyHook')],\n",
|
||||||
|
" verbose=True)"
|
||||||
|
],
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"output_type": "stream",
|
||||||
|
"name": "stderr",
|
||||||
|
"text": [
|
||||||
|
"colossalai - rank_0 - 2021-10-15 03:27:56,018 WARNING: No gradient handler is set up, please make sure you do not need to all-reduce the gradients after a training step.\n",
|
||||||
|
"colossalai - rank_0 - 2021-10-15 03:27:56,024 INFO: build LogMetricByEpochHook for train, priority = 1\n",
|
||||||
|
"colossalai - rank_0 - 2021-10-15 03:27:56,026 INFO: build LossHook for train, priority = 10\n",
|
||||||
|
"colossalai - rank_0 - 2021-10-15 03:27:56,029 INFO: build AccuracyHook for train, priority = 10\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "_JR2TuvH99Ik"
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"Then we set training configs. We train our model for 10 epochs and it will be evaluated every 1 epoch. Set `display_progress` to `True` to display the training / evaluating progress bar."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"metadata": {
|
||||||
|
"colab": {
|
||||||
|
"base_uri": "https://localhost:8080/"
|
||||||
|
},
|
||||||
|
"id": "w-J3IP-J1sfx",
|
||||||
|
"outputId": "bdb76939-04f1-4124-ce5e-3af44c0d902c"
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"num_epochs = 10\n",
|
||||||
|
"test_interval = 1\n",
|
||||||
|
"trainer.fit(\n",
|
||||||
|
" train_dataloader=trainloader,\n",
|
||||||
|
" test_dataloader=testloader,\n",
|
||||||
|
" max_epochs=num_epochs,\n",
|
||||||
|
" display_progress=True,\n",
|
||||||
|
" test_interval=test_interval\n",
|
||||||
|
" )"
|
||||||
|
],
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"output_type": "stream",
|
||||||
|
"name": "stderr",
|
||||||
|
"text": [
|
||||||
|
"[Epoch 0 train]: 0%| | 0/391 [00:00<?, ?it/s]/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /pytorch/c10/core/TensorImpl.h:1156.)\n",
|
||||||
|
" return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)\n",
|
||||||
|
"[Epoch 0 train]: 100%|██████████| 391/391 [00:14<00:00, 26.82it/s]\n",
|
||||||
|
"colossalai - rank_0 - 2021-10-15 03:28:11,088 INFO: Training - Epoch 1 - LogMetricByEpochHook: Loss = 2.29158\n",
|
||||||
|
"[Epoch 0 val]: 100%|██████████| 79/79 [00:02<00:00, 28.66it/s]\n",
|
||||||
|
"colossalai - rank_0 - 2021-10-15 03:28:14,040 INFO: Testing - Epoch 1 - LogMetricByEpochHook: Loss = 2.26517, Accuracy = 0.14820\n",
|
||||||
|
"[Epoch 1 train]: 100%|██████████| 391/391 [00:14<00:00, 26.31it/s]\n",
|
||||||
|
"colossalai - rank_0 - 2021-10-15 03:28:29,059 INFO: Training - Epoch 2 - LogMetricByEpochHook: Loss = 2.15763\n",
|
||||||
|
"[Epoch 1 val]: 100%|██████████| 79/79 [00:02<00:00, 28.50it/s]\n",
|
||||||
|
"colossalai - rank_0 - 2021-10-15 03:28:32,007 INFO: Testing - Epoch 2 - LogMetricByEpochHook: Loss = 2.00450, Accuracy = 0.27850\n",
|
||||||
|
"[Epoch 2 train]: 100%|██████████| 391/391 [00:14<00:00, 26.08it/s]\n",
|
||||||
|
"colossalai - rank_0 - 2021-10-15 03:28:47,167 INFO: Training - Epoch 3 - LogMetricByEpochHook: Loss = 1.85409\n",
|
||||||
|
"[Epoch 2 val]: 100%|██████████| 79/79 [00:02<00:00, 27.89it/s]\n",
|
||||||
|
"colossalai - rank_0 - 2021-10-15 03:28:50,168 INFO: Testing - Epoch 3 - LogMetricByEpochHook: Loss = 1.73788, Accuracy = 0.35990\n",
|
||||||
|
"[Epoch 3 train]: 100%|██████████| 391/391 [00:14<00:00, 26.09it/s]\n",
|
||||||
|
"colossalai - rank_0 - 2021-10-15 03:29:05,330 INFO: Training - Epoch 4 - LogMetricByEpochHook: Loss = 1.69363\n",
|
||||||
|
"[Epoch 3 val]: 100%|██████████| 79/79 [00:02<00:00, 28.43it/s]\n",
|
||||||
|
"colossalai - rank_0 - 2021-10-15 03:29:08,290 INFO: Testing - Epoch 4 - LogMetricByEpochHook: Loss = 1.65005, Accuracy = 0.39350\n",
|
||||||
|
"[Epoch 4 train]: 100%|██████████| 391/391 [00:15<00:00, 25.97it/s]\n",
|
||||||
|
"colossalai - rank_0 - 2021-10-15 03:29:23,530 INFO: Training - Epoch 5 - LogMetricByEpochHook: Loss = 1.61387\n",
|
||||||
|
"[Epoch 4 val]: 100%|██████████| 79/79 [00:02<00:00, 27.75it/s]\n",
|
||||||
|
"colossalai - rank_0 - 2021-10-15 03:29:26,515 INFO: Testing - Epoch 5 - LogMetricByEpochHook: Loss = 1.57507, Accuracy = 0.42430\n",
|
||||||
|
"[Epoch 5 train]: 100%|██████████| 391/391 [00:15<00:00, 25.92it/s]\n",
|
||||||
|
"colossalai - rank_0 - 2021-10-15 03:29:41,764 INFO: Training - Epoch 6 - LogMetricByEpochHook: Loss = 1.55712\n",
|
||||||
|
"[Epoch 5 val]: 100%|██████████| 79/79 [00:02<00:00, 27.51it/s]\n",
|
||||||
|
"colossalai - rank_0 - 2021-10-15 03:29:44,778 INFO: Testing - Epoch 6 - LogMetricByEpochHook: Loss = 1.53242, Accuracy = 0.43700\n",
|
||||||
|
"[Epoch 6 train]: 100%|██████████| 391/391 [00:14<00:00, 26.13it/s]\n",
|
||||||
|
"colossalai - rank_0 - 2021-10-15 03:29:59,927 INFO: Training - Epoch 7 - LogMetricByEpochHook: Loss = 1.51618\n",
|
||||||
|
"[Epoch 6 val]: 100%|██████████| 79/79 [00:02<00:00, 28.31it/s]\n",
|
||||||
|
"colossalai - rank_0 - 2021-10-15 03:30:02,884 INFO: Testing - Epoch 7 - LogMetricByEpochHook: Loss = 1.49720, Accuracy = 0.45430\n",
|
||||||
|
"[Epoch 7 train]: 100%|██████████| 391/391 [00:14<00:00, 26.23it/s]\n",
|
||||||
|
"colossalai - rank_0 - 2021-10-15 03:30:17,968 INFO: Training - Epoch 8 - LogMetricByEpochHook: Loss = 1.47857\n",
|
||||||
|
"[Epoch 7 val]: 100%|██████████| 79/79 [00:02<00:00, 27.97it/s]\n",
|
||||||
|
"colossalai - rank_0 - 2021-10-15 03:30:20,967 INFO: Testing - Epoch 8 - LogMetricByEpochHook: Loss = 1.45808, Accuracy = 0.46320\n",
|
||||||
|
"[Epoch 8 train]: 100%|██████████| 391/391 [00:14<00:00, 26.11it/s]\n",
|
||||||
|
"colossalai - rank_0 - 2021-10-15 03:30:36,129 INFO: Training - Epoch 9 - LogMetricByEpochHook: Loss = 1.44656\n",
|
||||||
|
"[Epoch 8 val]: 100%|██████████| 79/79 [00:02<00:00, 28.18it/s]\n",
|
||||||
|
"colossalai - rank_0 - 2021-10-15 03:30:39,096 INFO: Testing - Epoch 9 - LogMetricByEpochHook: Loss = 1.44903, Accuracy = 0.46580\n",
|
||||||
|
"[Epoch 9 train]: 100%|██████████| 391/391 [00:15<00:00, 25.97it/s]\n",
|
||||||
|
"colossalai - rank_0 - 2021-10-15 03:30:54,342 INFO: Training - Epoch 10 - LogMetricByEpochHook: Loss = 1.41120\n",
|
||||||
|
"[Epoch 9 val]: 100%|██████████| 79/79 [00:02<00:00, 28.05it/s]\n",
|
||||||
|
"colossalai - rank_0 - 2021-10-15 03:30:57,332 INFO: Testing - Epoch 10 - LogMetricByEpochHook: Loss = 1.41242, Accuracy = 0.48500\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
}
|
}
|
|
@ -3,26 +3,18 @@
|
||||||
|
|
||||||
import colossalai
|
import colossalai
|
||||||
from colossalai.core import global_context as gpc
|
from colossalai.core import global_context as gpc
|
||||||
from colossalai.engine import Engine
|
|
||||||
from colossalai.logging import get_global_dist_logger
|
from colossalai.logging import get_global_dist_logger
|
||||||
from colossalai.trainer import Trainer
|
from colossalai.trainer import Trainer
|
||||||
|
|
||||||
|
|
||||||
def run_trainer():
|
def run_trainer():
|
||||||
model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize()
|
engine, train_dataloader, test_dataloader = colossalai.initialize()
|
||||||
logger = get_global_dist_logger()
|
logger = get_global_dist_logger()
|
||||||
schedule.data_sync = False
|
engine.schedule.data_sync = False
|
||||||
engine = Engine(
|
|
||||||
model=model,
|
|
||||||
criterion=criterion,
|
|
||||||
optimizer=optimizer,
|
|
||||||
lr_scheduler=lr_scheduler,
|
|
||||||
schedule=schedule
|
|
||||||
)
|
|
||||||
logger.info("engine is built", ranks=[0])
|
logger.info("engine is built", ranks=[0])
|
||||||
|
|
||||||
trainer = Trainer(engine=engine,
|
trainer = Trainer(engine=engine,
|
||||||
hooks_cfg=gpc.config.hooks,
|
|
||||||
verbose=True)
|
verbose=True)
|
||||||
logger.info("trainer is built", ranks=[0])
|
logger.info("trainer is built", ranks=[0])
|
||||||
|
|
||||||
|
@ -30,7 +22,8 @@ def run_trainer():
|
||||||
trainer.fit(
|
trainer.fit(
|
||||||
train_dataloader=train_dataloader,
|
train_dataloader=train_dataloader,
|
||||||
test_dataloader=test_dataloader,
|
test_dataloader=test_dataloader,
|
||||||
max_epochs=gpc.config.num_epochs,
|
epochs=gpc.config.num_epochs,
|
||||||
|
hooks_cfg=gpc.config.hooks,
|
||||||
display_progress=True,
|
display_progress=True,
|
||||||
test_interval=2
|
test_interval=2
|
||||||
)
|
)
|
||||||
|
|
|
@ -3,5 +3,5 @@ torchvision>=0.9
|
||||||
numpy
|
numpy
|
||||||
tqdm
|
tqdm
|
||||||
psutil
|
psutil
|
||||||
tensorboardX
|
tensorboard
|
||||||
packaging
|
packaging
|
2
setup.py
2
setup.py
|
@ -121,7 +121,7 @@ if "--cuda_ext" in sys.argv:
|
||||||
install_requires = fetch_requirements('requirements/requirements.txt')
|
install_requires = fetch_requirements('requirements/requirements.txt')
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name='colossal-ai',
|
name='colossalai',
|
||||||
version='0.0.1-beta',
|
version='0.0.1-beta',
|
||||||
packages=find_packages(exclude=('csrc',
|
packages=find_packages(exclude=('csrc',
|
||||||
'tests',
|
'tests',
|
||||||
|
|
|
@ -27,8 +27,6 @@ train_data = dict(
|
||||||
dataloader=dict(
|
dataloader=dict(
|
||||||
batch_size=BATCH_SIZE,
|
batch_size=BATCH_SIZE,
|
||||||
pin_memory=True,
|
pin_memory=True,
|
||||||
# num_workers=1,
|
|
||||||
# shuffle=True,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -63,14 +61,6 @@ loss = dict(
|
||||||
type='CrossEntropyLoss2D',
|
type='CrossEntropyLoss2D',
|
||||||
)
|
)
|
||||||
|
|
||||||
# model = dict(
|
|
||||||
# type='VanillaResNet',
|
|
||||||
# block_type='ResNetBasicBlock',
|
|
||||||
# layers=[2, 2, 2, 2],
|
|
||||||
# num_cls=10
|
|
||||||
# )
|
|
||||||
|
|
||||||
|
|
||||||
model = dict(
|
model = dict(
|
||||||
type='VisionTransformerFromConfig',
|
type='VisionTransformerFromConfig',
|
||||||
tensor_splitting_cfg=dict(
|
tensor_splitting_cfg=dict(
|
||||||
|
@ -135,25 +125,26 @@ parallel = dict(
|
||||||
|
|
||||||
fp16 = dict(
|
fp16 = dict(
|
||||||
mode=AMP_TYPE.PARALLEL,
|
mode=AMP_TYPE.PARALLEL,
|
||||||
initial_scale=2 ** 8
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# fp16 = dict(
|
engine = dict(
|
||||||
# mode=None,
|
schedule=dict(
|
||||||
# )
|
num_microbatches=2
|
||||||
|
)
|
||||||
schedule = dict(
|
|
||||||
num_microbatches=2
|
|
||||||
)
|
|
||||||
lr_scheduler = dict(
|
|
||||||
type='LinearWarmupLR',
|
|
||||||
warmup_epochs=5
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
hooks = [
|
||||||
|
dict(
|
||||||
|
type='LRSchedulerHook',
|
||||||
|
by_epoch=True,
|
||||||
|
lr_scheduler_cfg=dict(
|
||||||
|
type='LinearWarmupLR',
|
||||||
|
warmup_steps=5
|
||||||
|
)
|
||||||
|
),
|
||||||
|
]
|
||||||
num_epochs = 60
|
num_epochs = 60
|
||||||
|
|
||||||
logging = dict(
|
logging = dict(
|
||||||
root_path='test_vit_2d_log'
|
root_path='test_vit_2d_log'
|
||||||
)
|
)
|
||||||
|
|
||||||
seed = 100
|
|
||||||
|
|
|
@ -124,14 +124,21 @@ parallel = dict(
|
||||||
tensor=dict(size=4, depth=1, mode='2.5d'),
|
tensor=dict(size=4, depth=1, mode='2.5d'),
|
||||||
)
|
)
|
||||||
|
|
||||||
lr_scheduler = dict(
|
hooks = [
|
||||||
type='LinearWarmupLR',
|
dict(
|
||||||
warmup_epochs=5
|
type='LRSchedulerHook',
|
||||||
)
|
by_epoch=True,
|
||||||
|
lr_scheduler_cfg=dict(
|
||||||
|
type='LinearWarmupLR',
|
||||||
|
warmup_steps=5
|
||||||
|
)
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
engine = dict(
|
||||||
schedule = dict(
|
schedule = dict(
|
||||||
num_microbatches=2
|
num_microbatches=2
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
num_epochs = 60
|
num_epochs = 60
|
||||||
num_microbatches = 1
|
|
||||||
|
|
|
@ -9,21 +9,22 @@ import torch.autograd
|
||||||
import colossalai
|
import colossalai
|
||||||
from colossalai.context.parallel_mode import ParallelMode
|
from colossalai.context.parallel_mode import ParallelMode
|
||||||
from colossalai.core import global_context as gpc
|
from colossalai.core import global_context as gpc
|
||||||
from colossalai.engine import Engine
|
|
||||||
from colossalai.logging import get_global_dist_logger
|
from colossalai.logging import get_global_dist_logger
|
||||||
from colossalai.nn.layer._parallel_utilities import _gather
|
from colossalai.nn.layer._parallel_utilities import _gather
|
||||||
|
|
||||||
CONFIG_PATH = Path(__file__).parent.parent.joinpath('configs/vit_2d.py')
|
CONFIG_PATH = Path(__file__).parent.parent.joinpath('configs/vit_2d.py')
|
||||||
|
|
||||||
|
|
||||||
def eval(engine):
|
def eval(engine, test_dataloader):
|
||||||
engine.eval()
|
engine.eval()
|
||||||
accumulated_loss = 0
|
accumulated_loss = 0
|
||||||
correct_sum = 0
|
correct_sum = 0
|
||||||
total_sum = 0
|
total_sum = 0
|
||||||
|
num_steps = len(test_dataloader)
|
||||||
|
data_iter = iter(test_dataloader)
|
||||||
|
|
||||||
for i in range(engine.schedule.num_steps):
|
for i in range(num_steps):
|
||||||
output, label, loss = engine.step()
|
output, label, loss = engine.step(data_iter)
|
||||||
|
|
||||||
if gpc.is_last_rank(ParallelMode.PIPELINE):
|
if gpc.is_last_rank(ParallelMode.PIPELINE):
|
||||||
# loss = sum(loss)
|
# loss = sum(loss)
|
||||||
|
@ -43,20 +44,22 @@ def eval(engine):
|
||||||
correct = torch.sum(label == output)
|
correct = torch.sum(label == output)
|
||||||
correct_sum += correct
|
correct_sum += correct
|
||||||
total_sum += label.size(0)
|
total_sum += label.size(0)
|
||||||
avg_loss = accumulated_loss / engine.schedule.num_steps
|
avg_loss = accumulated_loss / num_steps
|
||||||
return correct_sum, total_sum, avg_loss
|
return correct_sum, total_sum, avg_loss
|
||||||
|
|
||||||
|
|
||||||
def train(engine):
|
def train(engine, train_dataloader):
|
||||||
engine.train()
|
engine.train()
|
||||||
accumulated_loss = 0
|
accumulated_loss = 0
|
||||||
|
num_steps = len(train_dataloader)
|
||||||
|
data_iter = iter(train_dataloader)
|
||||||
|
|
||||||
for i in range(engine.schedule.num_steps):
|
for i in range(num_steps):
|
||||||
output, label, loss = engine.step()
|
output, label, loss = engine.step(data_iter)
|
||||||
|
|
||||||
if gpc.is_last_rank(ParallelMode.PIPELINE):
|
if gpc.is_last_rank(ParallelMode.PIPELINE):
|
||||||
accumulated_loss += loss.detach().cpu().numpy()
|
accumulated_loss += loss.detach().cpu().numpy()
|
||||||
avg_loss = accumulated_loss / engine.schedule.num_steps
|
avg_loss = accumulated_loss / num_steps
|
||||||
return avg_loss
|
return avg_loss
|
||||||
|
|
||||||
|
|
||||||
|
@ -64,25 +67,16 @@ def train(engine):
|
||||||
@pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus")
|
@pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus")
|
||||||
def test_2d_parallel_vision_transformer():
|
def test_2d_parallel_vision_transformer():
|
||||||
# init dist
|
# init dist
|
||||||
model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize(
|
engine, train_dataloader, test_dataloader = colossalai.initialize(CONFIG_PATH)
|
||||||
CONFIG_PATH)
|
|
||||||
logger = get_global_dist_logger()
|
logger = get_global_dist_logger()
|
||||||
|
|
||||||
engine = Engine(model=model,
|
|
||||||
train_dataloader=train_dataloader,
|
|
||||||
test_dataloader=test_dataloader,
|
|
||||||
criterion=criterion,
|
|
||||||
optimizer=optimizer,
|
|
||||||
lr_scheduler=lr_scheduler,
|
|
||||||
schedule=schedule)
|
|
||||||
|
|
||||||
for epoch in range(gpc.config.num_epochs):
|
for epoch in range(gpc.config.num_epochs):
|
||||||
train_loss = train(engine)
|
train_loss = train(engine, train_dataloader)
|
||||||
if gpc.is_last_rank(ParallelMode.PIPELINE):
|
if gpc.is_last_rank(ParallelMode.PIPELINE):
|
||||||
logger.info(f'epoch {epoch} - train loss: {train_loss}')
|
logger.info(f'epoch {epoch} - train loss: {train_loss}')
|
||||||
|
|
||||||
if epoch % 2 == 0:
|
if epoch % 2 == 0:
|
||||||
correct_sum, total_sum, eval_loss = eval(engine)
|
correct_sum, total_sum, eval_loss = eval(engine, test_dataloader)
|
||||||
if gpc.is_last_rank(ParallelMode.PIPELINE):
|
if gpc.is_last_rank(ParallelMode.PIPELINE):
|
||||||
logger.info(
|
logger.info(
|
||||||
f'epoch {epoch} - eval loss: {eval_loss}, total: {total_sum}, '
|
f'epoch {epoch} - eval loss: {eval_loss}, total: {total_sum}, '
|
||||||
|
|
|
@ -6,20 +6,22 @@ import torch.autograd
|
||||||
import colossalai
|
import colossalai
|
||||||
from colossalai.context.parallel_mode import ParallelMode
|
from colossalai.context.parallel_mode import ParallelMode
|
||||||
from colossalai.core import global_context as gpc
|
from colossalai.core import global_context as gpc
|
||||||
from colossalai.engine import Engine
|
|
||||||
from colossalai.logging import get_global_dist_logger
|
from colossalai.logging import get_global_dist_logger
|
||||||
from colossalai.nn.layer._parallel_utilities import _gather
|
from colossalai.nn.layer._parallel_utilities import _gather
|
||||||
|
|
||||||
CONFIG_PATH = Path(__file__).parent.parent.joinpath('configs/vit_2p5d.py')
|
CONFIG_PATH = Path(__file__).parent.parent.joinpath('configs/vit_2p5d.py')
|
||||||
|
|
||||||
def eval(engine):
|
|
||||||
|
def eval(engine, test_dataloader):
|
||||||
engine.eval()
|
engine.eval()
|
||||||
accumulated_loss = 0
|
accumulated_loss = 0
|
||||||
correct_sum = 0
|
correct_sum = 0
|
||||||
total_sum = 0
|
total_sum = 0
|
||||||
|
num_steps = len(test_dataloader)
|
||||||
|
data_iter = iter(test_dataloader)
|
||||||
|
|
||||||
for i in range(engine.schedule.num_steps):
|
for i in range(num_steps):
|
||||||
output, label, loss = engine.step()
|
output, label, loss = engine.step(data_iter)
|
||||||
|
|
||||||
if gpc.is_last_rank(ParallelMode.PIPELINE):
|
if gpc.is_last_rank(ParallelMode.PIPELINE):
|
||||||
accumulated_loss += loss.detach().cpu().numpy()
|
accumulated_loss += loss.detach().cpu().numpy()
|
||||||
|
@ -43,21 +45,23 @@ def eval(engine):
|
||||||
correct = torch.sum(label == output)
|
correct = torch.sum(label == output)
|
||||||
correct_sum += correct
|
correct_sum += correct
|
||||||
total_sum += label.size(0)
|
total_sum += label.size(0)
|
||||||
avg_loss = accumulated_loss / engine.schedule.num_steps
|
avg_loss = accumulated_loss / num_steps
|
||||||
return correct_sum, total_sum, avg_loss
|
return correct_sum, total_sum, avg_loss
|
||||||
|
|
||||||
|
|
||||||
def train(engine):
|
def train(engine, train_dataloader):
|
||||||
engine.train()
|
engine.train()
|
||||||
accumulated_loss = 0
|
accumulated_loss = 0
|
||||||
|
num_steps = len(train_dataloader)
|
||||||
|
data_iter = iter(train_dataloader)
|
||||||
|
|
||||||
|
for i in range(num_steps):
|
||||||
|
output, label, loss = engine.step(data_iter)
|
||||||
|
|
||||||
for i in range(engine.schedule.num_steps):
|
|
||||||
output, label, loss = engine.step()
|
|
||||||
|
|
||||||
if gpc.is_last_rank(ParallelMode.PIPELINE):
|
if gpc.is_last_rank(ParallelMode.PIPELINE):
|
||||||
accumulated_loss += loss.detach().cpu().numpy()
|
accumulated_loss += loss.detach().cpu().numpy()
|
||||||
|
|
||||||
avg_loss = accumulated_loss / engine.schedule.num_steps
|
avg_loss = accumulated_loss / num_steps
|
||||||
return avg_loss
|
return avg_loss
|
||||||
|
|
||||||
|
|
||||||
|
@ -65,25 +69,16 @@ def train(engine):
|
||||||
@pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus")
|
@pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus")
|
||||||
def test_2p5d_parallel_vision_transformer():
|
def test_2p5d_parallel_vision_transformer():
|
||||||
# init dist
|
# init dist
|
||||||
model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize(
|
engine, train_dataloader, test_dataloader = colossalai.initialize(CONFIG_PATH)
|
||||||
CONFIG_PATH)
|
|
||||||
logger = get_global_dist_logger()
|
logger = get_global_dist_logger()
|
||||||
|
|
||||||
engine = Engine(model=model,
|
|
||||||
train_dataloader=train_dataloader,
|
|
||||||
test_dataloader=test_dataloader,
|
|
||||||
criterion=criterion,
|
|
||||||
optimizer=optimizer,
|
|
||||||
lr_scheduler=lr_scheduler,
|
|
||||||
schedule=schedule)
|
|
||||||
|
|
||||||
for epoch in range(gpc.config.num_epochs):
|
for epoch in range(gpc.config.num_epochs):
|
||||||
train_loss = train(engine)
|
train_loss = train(engine, train_dataloader)
|
||||||
if gpc.is_last_rank(ParallelMode.PIPELINE):
|
if gpc.is_last_rank(ParallelMode.PIPELINE):
|
||||||
logger.info(f'epoch {epoch} - train loss: {train_loss}')
|
logger.info(f'epoch {epoch} - train loss: {train_loss}')
|
||||||
|
|
||||||
if epoch % 2 == 0:
|
if epoch % 2 == 0:
|
||||||
correct_sum, total_sum, eval_loss = eval(engine)
|
correct_sum, total_sum, eval_loss = eval(engine, test_dataloader)
|
||||||
if gpc.is_last_rank(ParallelMode.PIPELINE):
|
if gpc.is_last_rank(ParallelMode.PIPELINE):
|
||||||
logger.info(
|
logger.info(
|
||||||
f'epoch {epoch} - eval loss: {eval_loss}, total: {total_sum}, '
|
f'epoch {epoch} - eval loss: {eval_loss}, total: {total_sum}, '
|
||||||
|
@ -91,4 +86,4 @@ def test_2p5d_parallel_vision_transformer():
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_2p5d_parallel_vision_transformer()
|
test_2p5d_parallel_vision_transformer()
|
||||||
|
|
|
@ -38,5 +38,3 @@ optimizer = dict(type='Adam', lr=0.001)
|
||||||
|
|
||||||
loss = dict(type='CrossEntropyLoss')
|
loss = dict(type='CrossEntropyLoss')
|
||||||
|
|
||||||
# set_device_func = lambda global_rank, world_size: global_rank % 4
|
|
||||||
seed = 1024
|
|
||||||
|
|
|
@ -40,6 +40,3 @@ optimizer = dict(type='Adam', lr=0.001)
|
||||||
|
|
||||||
loss = dict(type='CrossEntropyLoss')
|
loss = dict(type='CrossEntropyLoss')
|
||||||
fp16 = dict(mode=AMP_TYPE.APEX)
|
fp16 = dict(mode=AMP_TYPE.APEX)
|
||||||
|
|
||||||
# set_device_func = lambda global_rank, world_size: global_rank % 4
|
|
||||||
seed = 1024
|
|
||||||
|
|
|
@ -40,6 +40,3 @@ optimizer = dict(type='Adam', lr=0.001)
|
||||||
|
|
||||||
loss = dict(type='CrossEntropyLoss')
|
loss = dict(type='CrossEntropyLoss')
|
||||||
fp16 = dict(mode=AMP_TYPE.TORCH)
|
fp16 = dict(mode=AMP_TYPE.TORCH)
|
||||||
|
|
||||||
# set_device_func = lambda global_rank, world_size: global_rank % 4
|
|
||||||
seed = 1024
|
|
||||||
|
|
|
@ -38,11 +38,9 @@ parallel = dict(
|
||||||
tensor=dict(size=1, mode=None)
|
tensor=dict(size=1, mode=None)
|
||||||
)
|
)
|
||||||
|
|
||||||
schedule = dict(
|
engine = dict(
|
||||||
num_microbatches=4
|
schedule=dict(
|
||||||
|
num_microbatches=4
|
||||||
|
)
|
||||||
)
|
)
|
||||||
num_pipeling_batches = 2
|
|
||||||
seed = 1024
|
|
||||||
lr_scheduler = dict(type='LinearWarmupLR', warmup_steps=5)
|
|
||||||
|
|
||||||
num_epochs = 10
|
num_epochs = 10
|
||||||
|
|
|
@ -8,7 +8,6 @@ import torch
|
||||||
|
|
||||||
from colossalai import initialize
|
from colossalai import initialize
|
||||||
from colossalai.core import global_context as gpc
|
from colossalai.core import global_context as gpc
|
||||||
from colossalai.engine import Engine
|
|
||||||
from colossalai.logging import get_global_dist_logger
|
from colossalai.logging import get_global_dist_logger
|
||||||
from colossalai.utils import report_memory_usage
|
from colossalai.utils import report_memory_usage
|
||||||
|
|
||||||
|
@ -24,20 +23,13 @@ NO_PIPE_CONFIG_PATH = osp.join(DIR_PATH, '../configs/non_pipeline_resnet_apex_am
|
||||||
|
|
||||||
|
|
||||||
def run_no_pipeline(config):
|
def run_no_pipeline(config):
|
||||||
model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = initialize(config)
|
engine, train_dataloader, test_dataloader = initialize(config)
|
||||||
logger = get_global_dist_logger()
|
logger = get_global_dist_logger()
|
||||||
rank = torch.distributed.get_rank()
|
rank = torch.distributed.get_rank()
|
||||||
|
|
||||||
engine = Engine(model=model,
|
|
||||||
train_dataloader=train_dataloader,
|
|
||||||
criterion=criterion,
|
|
||||||
optimizer=optimizer,
|
|
||||||
schedule=schedule)
|
|
||||||
engine.train()
|
engine.train()
|
||||||
logger.info('lr = %g' % engine.get_lr())
|
output, label, loss = engine.step(iter(train_dataloader))
|
||||||
output, label, loss = engine.step()
|
|
||||||
logger.info('Rank {} returns: {}'.format(rank, loss.item()))
|
logger.info('Rank {} returns: {}'.format(rank, loss.item()))
|
||||||
logger.info('lr = %g' % engine.get_lr())
|
|
||||||
|
|
||||||
gpc.destroy()
|
gpc.destroy()
|
||||||
logger.info('Test engine finished')
|
logger.info('Test engine finished')
|
||||||
|
|
|
@ -8,7 +8,6 @@ import torch
|
||||||
|
|
||||||
from colossalai import initialize
|
from colossalai import initialize
|
||||||
from colossalai.core import global_context as gpc
|
from colossalai.core import global_context as gpc
|
||||||
from colossalai.engine import Engine
|
|
||||||
from colossalai.logging import get_global_dist_logger
|
from colossalai.logging import get_global_dist_logger
|
||||||
from colossalai.utils import report_memory_usage
|
from colossalai.utils import report_memory_usage
|
||||||
|
|
||||||
|
@ -26,21 +25,14 @@ NO_PIPE_CONFIG_PATH = osp.join(DIR_PATH, '../configs/non_pipeline_resnet.py')
|
||||||
def test_no_pipeline(config):
|
def test_no_pipeline(config):
|
||||||
print('Test no pipeline engine start')
|
print('Test no pipeline engine start')
|
||||||
|
|
||||||
model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = initialize(config)
|
engine, train_dataloader, test_dataloader = initialize(config)
|
||||||
logger = get_global_dist_logger()
|
logger = get_global_dist_logger()
|
||||||
|
|
||||||
rank = torch.distributed.get_rank()
|
rank = torch.distributed.get_rank()
|
||||||
engine = Engine(model=model,
|
|
||||||
train_dataloader=train_dataloader,
|
|
||||||
criterion=criterion,
|
|
||||||
optimizer=optimizer,
|
|
||||||
schedule=schedule)
|
|
||||||
|
|
||||||
engine.train()
|
engine.train()
|
||||||
logger.info('lr = %g' % engine.get_lr())
|
output, label, loss = engine.step(iter(train_dataloader))
|
||||||
output, label, loss = engine.step()
|
|
||||||
logger.info('Rank {} returns: {}'.format(rank, loss.item()))
|
logger.info('Rank {} returns: {}'.format(rank, loss.item()))
|
||||||
logger.info('lr = %g' % engine.get_lr())
|
|
||||||
|
|
||||||
gpc.destroy()
|
gpc.destroy()
|
||||||
logger.info('Test engine finished')
|
logger.info('Test engine finished')
|
||||||
|
|
|
@ -8,7 +8,6 @@ import torch
|
||||||
|
|
||||||
from colossalai import initialize
|
from colossalai import initialize
|
||||||
from colossalai.core import global_context as gpc
|
from colossalai.core import global_context as gpc
|
||||||
from colossalai.engine import Engine
|
|
||||||
from colossalai.logging import get_global_dist_logger
|
from colossalai.logging import get_global_dist_logger
|
||||||
from colossalai.utils import report_memory_usage
|
from colossalai.utils import report_memory_usage
|
||||||
|
|
||||||
|
@ -26,21 +25,13 @@ NO_PIPE_CONFIG_PATH = osp.join(DIR_PATH, '../configs/non_pipeline_resnet_torch_a
|
||||||
def test_no_pipeline(config):
|
def test_no_pipeline(config):
|
||||||
print('Test no pipeline engine start')
|
print('Test no pipeline engine start')
|
||||||
|
|
||||||
model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = initialize(config)
|
engine, train_dataloader, test_dataloader = initialize(config)
|
||||||
logger = get_global_dist_logger()
|
logger = get_global_dist_logger()
|
||||||
|
|
||||||
rank = torch.distributed.get_rank()
|
rank = torch.distributed.get_rank()
|
||||||
engine = Engine(model=model,
|
|
||||||
train_dataloader=train_dataloader,
|
|
||||||
criterion=criterion,
|
|
||||||
optimizer=optimizer,
|
|
||||||
schedule=schedule)
|
|
||||||
|
|
||||||
engine.train()
|
engine.train()
|
||||||
logger.info('lr = %g' % engine.get_lr())
|
output, label, loss = engine.step(iter(train_dataloader))
|
||||||
output, label, loss = engine.step()
|
|
||||||
logger.info('Rank {} returns: {}'.format(rank, loss.item()))
|
logger.info('Rank {} returns: {}'.format(rank, loss.item()))
|
||||||
logger.info('lr = %g' % engine.get_lr())
|
|
||||||
|
|
||||||
gpc.destroy()
|
gpc.destroy()
|
||||||
logger.info('Test engine finished')
|
logger.info('Test engine finished')
|
||||||
|
|
|
@ -5,6 +5,7 @@ import os.path as osp
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from colossalai.context import ParallelMode
|
||||||
from colossalai.core import global_context as gpc
|
from colossalai.core import global_context as gpc
|
||||||
from colossalai.initialize import initialize
|
from colossalai.initialize import initialize
|
||||||
from colossalai.logging import get_global_dist_logger
|
from colossalai.logging import get_global_dist_logger
|
||||||
|
@ -22,13 +23,25 @@ CONFIG_PATH = osp.join(DIR_PATH, '../configs/pipeline_vanilla_resnet.py')
|
||||||
@pytest.mark.skip("This test should be invoked using the test.sh provided")
|
@pytest.mark.skip("This test should be invoked using the test.sh provided")
|
||||||
@pytest.mark.dist
|
@pytest.mark.dist
|
||||||
def test_schedule():
|
def test_schedule():
|
||||||
model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = initialize(CONFIG_PATH)
|
engine, train_dataloader, test_dataloader = initialize(CONFIG_PATH)
|
||||||
logger = get_global_dist_logger()
|
logger = get_global_dist_logger()
|
||||||
|
|
||||||
schedule.zero_grad()
|
model = engine.model
|
||||||
output, label, losses = schedule.forward_backward_step(forward_only=False)
|
optimizer = engine.optimizer
|
||||||
schedule.step()
|
criterion = engine.criterion
|
||||||
logger.info('losses: {}'.format([loss.item() for loss in losses]))
|
schedule = engine._schedule
|
||||||
|
|
||||||
|
output, label, loss = schedule.forward_backward_step(
|
||||||
|
data_iter=iter(train_dataloader),
|
||||||
|
model=model,
|
||||||
|
optimizer=optimizer,
|
||||||
|
criterion=criterion,
|
||||||
|
forward_only=False
|
||||||
|
)
|
||||||
|
schedule.optimizer_step(model, optimizer)
|
||||||
|
|
||||||
|
if gpc.is_last_rank(ParallelMode.PIPELINE):
|
||||||
|
logger.info('losses: {}'.format(loss))
|
||||||
|
|
||||||
gpc.destroy()
|
gpc.destroy()
|
||||||
logger.info('training finished')
|
logger.info('training finished')
|
||||||
|
|
|
@ -9,7 +9,6 @@ import torch
|
||||||
from colossalai import initialize
|
from colossalai import initialize
|
||||||
from colossalai.context import ParallelMode
|
from colossalai.context import ParallelMode
|
||||||
from colossalai.core import global_context as gpc
|
from colossalai.core import global_context as gpc
|
||||||
from colossalai.engine import Engine
|
|
||||||
from colossalai.logging import get_global_dist_logger
|
from colossalai.logging import get_global_dist_logger
|
||||||
|
|
||||||
NUM_BATCH = 128
|
NUM_BATCH = 128
|
||||||
|
@ -23,22 +22,14 @@ PIPE_CONFIG_PATH = osp.join(DIR_PATH, '../configs/pipeline_vanilla_resnet.py')
|
||||||
|
|
||||||
|
|
||||||
def run_pipeline(config):
|
def run_pipeline(config):
|
||||||
model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = initialize(config)
|
engine, train_dataloader, test_dataloader = initialize(config)
|
||||||
logger = get_global_dist_logger()
|
logger = get_global_dist_logger()
|
||||||
rank = torch.distributed.get_rank()
|
rank = torch.distributed.get_rank()
|
||||||
engine = Engine(model=model,
|
|
||||||
train_dataloader=train_dataloader,
|
|
||||||
criterion=criterion,
|
|
||||||
optimizer=optimizer,
|
|
||||||
lr_scheduler=lr_scheduler,
|
|
||||||
schedule=schedule)
|
|
||||||
|
|
||||||
engine.train()
|
engine.train()
|
||||||
logger.info('lr = %g' % engine.get_lr())
|
outputs, labels, loss = engine.step(iter(train_dataloader))
|
||||||
outputs, labels, loss = engine.step()
|
|
||||||
if gpc.is_last_rank(ParallelMode.PIPELINE):
|
if gpc.is_last_rank(ParallelMode.PIPELINE):
|
||||||
logger.info('losses: {}'.format(rank, loss.item()))
|
logger.info('losses: {}'.format(rank, loss.item()))
|
||||||
logger.info('lr = %g' % engine.get_lr())
|
|
||||||
|
|
||||||
gpc.destroy()
|
gpc.destroy()
|
||||||
logger.info('Test engine pipeline finished')
|
logger.info('Test engine pipeline finished')
|
||||||
|
|
|
@ -132,9 +132,12 @@ fp16 = dict(
|
||||||
initial_scale=2 ** 4
|
initial_scale=2 ** 4
|
||||||
)
|
)
|
||||||
|
|
||||||
|
num_epochs = 60
|
||||||
|
|
||||||
|
|
||||||
lr_scheduler = dict(
|
lr_scheduler = dict(
|
||||||
type='LinearWarmupLR',
|
type='LinearWarmupLR',
|
||||||
warmup_epochs=5
|
warmup_steps=5,
|
||||||
|
total_steps=num_epochs
|
||||||
)
|
)
|
||||||
|
|
||||||
num_epochs = 60
|
|
||||||
|
|
|
@ -7,23 +7,25 @@ import pytest
|
||||||
import torch.autograd
|
import torch.autograd
|
||||||
|
|
||||||
import colossalai
|
import colossalai
|
||||||
|
from colossalai.builder import build_lr_scheduler
|
||||||
from colossalai.context.parallel_mode import ParallelMode
|
from colossalai.context.parallel_mode import ParallelMode
|
||||||
from colossalai.core import global_context as gpc
|
from colossalai.core import global_context as gpc
|
||||||
from colossalai.engine import Engine
|
|
||||||
from colossalai.logging import get_global_dist_logger
|
from colossalai.logging import get_global_dist_logger
|
||||||
from colossalai.nn.layer._parallel_utilities import _gather
|
from colossalai.nn.layer._parallel_utilities import _gather
|
||||||
|
|
||||||
CONFIG_PATH = Path(__file__).parent.parent.joinpath('configs/vit_2d.py')
|
CONFIG_PATH = Path(__file__).parent.parent.joinpath('configs/vit_2d.py')
|
||||||
|
|
||||||
|
|
||||||
def eval(engine):
|
def eval(engine, test_dataloader):
|
||||||
engine.eval()
|
engine.eval()
|
||||||
accumulated_loss = 0
|
accumulated_loss = 0
|
||||||
correct_sum = 0
|
correct_sum = 0
|
||||||
total_sum = 0
|
total_sum = 0
|
||||||
|
num_steps = len(test_dataloader)
|
||||||
|
data_iter = iter(test_dataloader)
|
||||||
|
|
||||||
for i in range(engine.schedule.num_steps):
|
for i in range(num_steps):
|
||||||
output, label, loss = engine.step()
|
output, label, loss = engine.step(data_iter)
|
||||||
accumulated_loss += loss.detach().cpu().numpy()
|
accumulated_loss += loss.detach().cpu().numpy()
|
||||||
|
|
||||||
output = _gather(
|
output = _gather(
|
||||||
|
@ -40,18 +42,21 @@ def eval(engine):
|
||||||
correct = torch.sum(label[0] == output)
|
correct = torch.sum(label[0] == output)
|
||||||
correct_sum += correct
|
correct_sum += correct
|
||||||
total_sum += label[0].size(0)
|
total_sum += label[0].size(0)
|
||||||
avg_loss = accumulated_loss / engine.schedule.num_steps
|
avg_loss = accumulated_loss / num_steps
|
||||||
return correct_sum, total_sum, avg_loss
|
return correct_sum, total_sum, avg_loss
|
||||||
|
|
||||||
|
|
||||||
def train(engine):
|
def train(engine, train_dataloader, lr_scheduler):
|
||||||
engine.train()
|
engine.train()
|
||||||
accumulated_loss = 0
|
accumulated_loss = 0
|
||||||
|
num_steps = len(train_dataloader)
|
||||||
|
data_iter = iter(train_dataloader)
|
||||||
|
|
||||||
for i in range(engine.schedule.num_steps):
|
for i in range(num_steps):
|
||||||
output, label, loss = engine.step()
|
output, label, loss = engine.step(data_iter)
|
||||||
accumulated_loss += loss.squeeze(0).detach().cpu().numpy()
|
accumulated_loss += loss.squeeze(0).detach().cpu().numpy()
|
||||||
avg_loss = accumulated_loss / engine.schedule.num_steps
|
avg_loss = accumulated_loss / num_steps
|
||||||
|
lr_scheduler.step()
|
||||||
return avg_loss
|
return avg_loss
|
||||||
|
|
||||||
|
|
||||||
|
@ -59,26 +64,18 @@ def train(engine):
|
||||||
@pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus")
|
@pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus")
|
||||||
def test_2d_parallel_vision_transformer():
|
def test_2d_parallel_vision_transformer():
|
||||||
# init dist
|
# init dist
|
||||||
model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize(
|
engine, train_dataloader, test_dataloader = colossalai.initialize(CONFIG_PATH)
|
||||||
CONFIG_PATH)
|
lr_scheduler = build_lr_scheduler(gpc.config.lr_scheduler, engine.optimizer)
|
||||||
logger = get_global_dist_logger()
|
logger = get_global_dist_logger()
|
||||||
|
|
||||||
engine = Engine(model=model,
|
|
||||||
train_dataloader=train_dataloader,
|
|
||||||
test_dataloader=test_dataloader,
|
|
||||||
criterion=criterion,
|
|
||||||
optimizer=optimizer,
|
|
||||||
lr_scheduler=lr_scheduler,
|
|
||||||
schedule=schedule)
|
|
||||||
|
|
||||||
logger.info('start training')
|
logger.info('start training')
|
||||||
for epoch in range(gpc.config.num_epochs):
|
for epoch in range(gpc.config.num_epochs):
|
||||||
train_loss = train(engine)
|
train_loss = train(engine, train_dataloader, lr_scheduler)
|
||||||
|
|
||||||
logger.info(f'epoch {epoch} - train loss: {train_loss}')
|
logger.info(f'epoch {epoch} - train loss: {train_loss}')
|
||||||
|
|
||||||
if epoch % 2 == 0:
|
if epoch % 2 == 0:
|
||||||
correct_sum, total_sum, eval_loss = eval(engine)
|
correct_sum, total_sum, eval_loss = eval(engine, test_dataloader)
|
||||||
logger.info(
|
logger.info(
|
||||||
f'epoch {epoch} - eval loss: {eval_loss}, total: {total_sum}, '
|
f'epoch {epoch} - eval loss: {eval_loss}, total: {total_sum}, '
|
||||||
f'correct: {correct_sum}, acc: {correct_sum / total_sum}')
|
f'correct: {correct_sum}, acc: {correct_sum / total_sum}')
|
||||||
|
|
|
@ -102,6 +102,6 @@ parallel = dict(
|
||||||
tensor=dict(size=4, mode='2d'),
|
tensor=dict(size=4, mode='2d'),
|
||||||
)
|
)
|
||||||
|
|
||||||
lr_scheduler = dict(type='LinearWarmupLR', warmup_epochs=5)
|
|
||||||
|
|
||||||
num_epochs = 60
|
num_epochs = 60
|
||||||
|
|
||||||
|
lr_scheduler = dict(type='LinearWarmupLR', warmup_steps=5, total_steps=num_epochs)
|
||||||
|
|
|
@ -125,13 +125,6 @@ parallel = dict(
|
||||||
tensor=dict(size=4, depth=1, mode='2.5d'),
|
tensor=dict(size=4, depth=1, mode='2.5d'),
|
||||||
)
|
)
|
||||||
|
|
||||||
lr_scheduler = dict(
|
|
||||||
type='LinearWarmupLR',
|
|
||||||
warmup_epochs=5
|
|
||||||
)
|
|
||||||
|
|
||||||
schedule = dict(
|
|
||||||
num_microbatches=8
|
|
||||||
)
|
|
||||||
|
|
||||||
num_epochs = 60
|
num_epochs = 60
|
||||||
|
|
||||||
|
lr_scheduler = dict(type='LinearWarmupLR', warmup_steps=5, total_steps=num_epochs)
|
||||||
|
|
|
@ -116,9 +116,14 @@ hooks = [
|
||||||
weight_parallel_mode=ParallelMode.PARALLEL_3D_WEIGHT,
|
weight_parallel_mode=ParallelMode.PARALLEL_3D_WEIGHT,
|
||||||
),
|
),
|
||||||
dict(type='LossHook'),
|
dict(type='LossHook'),
|
||||||
# dict(type='TensorboardHook', log_dir='./tfb_logs'),
|
dict(
|
||||||
# dict(type='SaveCheckpointHook', interval=5, checkpoint_dir='./ckpt'),
|
type='LRSchedulerHook',
|
||||||
# dict(type='LoadCheckpointHook', epoch=20, checkpoint_dir='./ckpt')
|
by_epoch=True,
|
||||||
|
lr_scheduler_cfg=dict(
|
||||||
|
type='LinearWarmupLR',
|
||||||
|
warmup_steps=5
|
||||||
|
)
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
parallel = dict(
|
parallel = dict(
|
||||||
|
@ -127,12 +132,4 @@ parallel = dict(
|
||||||
tensor=dict(mode='3d', size=8),
|
tensor=dict(mode='3d', size=8),
|
||||||
)
|
)
|
||||||
|
|
||||||
# fp16 = dict(mode=AMP_TYPE.PARALLEL, initial_scale=2 ** 6)
|
|
||||||
|
|
||||||
lr_scheduler = dict(type='LinearWarmupLR', warmup_epochs=5)
|
|
||||||
|
|
||||||
# schedule = dict(num_microbatches=4)
|
|
||||||
|
|
||||||
num_epochs = 60
|
num_epochs = 60
|
||||||
|
|
||||||
seed = 42
|
|
||||||
|
|
|
@ -7,23 +7,25 @@ import pytest
|
||||||
import torch.autograd
|
import torch.autograd
|
||||||
|
|
||||||
import colossalai
|
import colossalai
|
||||||
|
from colossalai.builder import build_lr_scheduler
|
||||||
from colossalai.context.parallel_mode import ParallelMode
|
from colossalai.context.parallel_mode import ParallelMode
|
||||||
from colossalai.core import global_context as gpc
|
from colossalai.core import global_context as gpc
|
||||||
from colossalai.engine import Engine
|
|
||||||
from colossalai.logging import get_global_dist_logger
|
from colossalai.logging import get_global_dist_logger
|
||||||
from colossalai.nn.layer._parallel_utilities import _gather
|
from colossalai.nn.layer._parallel_utilities import _gather
|
||||||
|
|
||||||
CONFIG_PATH = Path(__file__).parent.parent.joinpath('configs/vit_2d.py')
|
CONFIG_PATH = Path(__file__).parent.parent.joinpath('configs/vit_2d.py')
|
||||||
|
|
||||||
|
|
||||||
def eval(engine):
|
def eval(engine, test_dataloader):
|
||||||
engine.eval()
|
engine.eval()
|
||||||
accumulated_loss = 0
|
accumulated_loss = 0
|
||||||
correct_sum = 0
|
correct_sum = 0
|
||||||
total_sum = 0
|
total_sum = 0
|
||||||
|
num_steps = len(test_dataloader)
|
||||||
|
data_iter = iter(test_dataloader)
|
||||||
|
|
||||||
for i in range(engine.schedule.num_steps):
|
for i in range(num_steps):
|
||||||
output, label, loss = engine.step()
|
output, label, loss = engine.step(data_iter)
|
||||||
accumulated_loss += loss.detach().cpu().numpy()
|
accumulated_loss += loss.detach().cpu().numpy()
|
||||||
|
|
||||||
output = _gather(
|
output = _gather(
|
||||||
|
@ -40,18 +42,21 @@ def eval(engine):
|
||||||
correct = torch.sum(label[0] == output)
|
correct = torch.sum(label[0] == output)
|
||||||
correct_sum += correct
|
correct_sum += correct
|
||||||
total_sum += label[0].size(0)
|
total_sum += label[0].size(0)
|
||||||
avg_loss = accumulated_loss / engine.schedule.num_steps
|
avg_loss = accumulated_loss / num_steps
|
||||||
return correct_sum, total_sum, avg_loss
|
return correct_sum, total_sum, avg_loss
|
||||||
|
|
||||||
|
|
||||||
def train(engine):
|
def train(engine, train_dataloader, lr_scheduler):
|
||||||
engine.train()
|
engine.train()
|
||||||
accumulated_loss = 0
|
accumulated_loss = 0
|
||||||
|
num_steps = len(train_dataloader)
|
||||||
|
data_iter = iter(train_dataloader)
|
||||||
|
|
||||||
for i in range(engine.schedule.num_steps):
|
for i in range(num_steps):
|
||||||
output, label, loss = engine.step()
|
output, label, loss = engine.step(data_iter)
|
||||||
accumulated_loss += loss.detach().cpu().numpy()
|
accumulated_loss += loss.detach().cpu().numpy()
|
||||||
avg_loss = accumulated_loss / engine.schedule.num_steps
|
avg_loss = accumulated_loss / num_steps
|
||||||
|
lr_scheduler.step()
|
||||||
return avg_loss
|
return avg_loss
|
||||||
|
|
||||||
|
|
||||||
|
@ -59,25 +64,17 @@ def train(engine):
|
||||||
@pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus")
|
@pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus")
|
||||||
def test_2d_parallel_vision_transformer():
|
def test_2d_parallel_vision_transformer():
|
||||||
# init dist
|
# init dist
|
||||||
model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize(
|
engine, train_dataloader, test_dataloader = colossalai.initialize(CONFIG_PATH)
|
||||||
CONFIG_PATH)
|
lr_scheduler = build_lr_scheduler(gpc.config.lr_scheduler, engine.optimizer)
|
||||||
logger = get_global_dist_logger()
|
logger = get_global_dist_logger()
|
||||||
|
|
||||||
engine = Engine(model=model,
|
|
||||||
train_dataloader=train_dataloader,
|
|
||||||
test_dataloader=test_dataloader,
|
|
||||||
criterion=criterion,
|
|
||||||
optimizer=optimizer,
|
|
||||||
lr_scheduler=lr_scheduler,
|
|
||||||
schedule=schedule)
|
|
||||||
|
|
||||||
logger.info('start training')
|
logger.info('start training')
|
||||||
for epoch in range(gpc.config.num_epochs):
|
for epoch in range(gpc.config.num_epochs):
|
||||||
train_loss = train(engine)
|
train_loss = train(engine, train_dataloader, lr_scheduler)
|
||||||
logger.info(f'epoch {epoch} - train loss: {train_loss}')
|
logger.info(f'epoch {epoch} - train loss: {train_loss}')
|
||||||
|
|
||||||
if epoch % 2 == 0:
|
if epoch % 2 == 0:
|
||||||
correct_sum, total_sum, eval_loss = eval(engine)
|
correct_sum, total_sum, eval_loss = eval(engine, test_dataloader)
|
||||||
logger.info(
|
logger.info(
|
||||||
f'epoch {epoch} - eval loss: {eval_loss}, total: {total_sum}, '
|
f'epoch {epoch} - eval loss: {eval_loss}, total: {total_sum}, '
|
||||||
f'correct: {correct_sum}, acc: {correct_sum / total_sum}')
|
f'correct: {correct_sum}, acc: {correct_sum / total_sum}')
|
||||||
|
|
|
@ -4,22 +4,25 @@ import pytest
|
||||||
import torch.autograd
|
import torch.autograd
|
||||||
|
|
||||||
import colossalai
|
import colossalai
|
||||||
|
from colossalai.builder import build_lr_scheduler
|
||||||
from colossalai.context.parallel_mode import ParallelMode
|
from colossalai.context.parallel_mode import ParallelMode
|
||||||
from colossalai.core import global_context as gpc
|
from colossalai.core import global_context as gpc
|
||||||
from colossalai.engine import Engine
|
|
||||||
from colossalai.logging import get_global_dist_logger
|
from colossalai.logging import get_global_dist_logger
|
||||||
from colossalai.nn.layer._parallel_utilities import _gather
|
from colossalai.nn.layer._parallel_utilities import _gather
|
||||||
|
|
||||||
CONFIG_PATH = Path(__file__).parent.parent.joinpath('configs/vit_2p5d.py')
|
CONFIG_PATH = Path(__file__).parent.parent.joinpath('configs/vit_2p5d.py')
|
||||||
|
|
||||||
def eval(engine):
|
|
||||||
|
def eval(engine, test_dataloader):
|
||||||
engine.eval()
|
engine.eval()
|
||||||
accumulated_loss = 0
|
accumulated_loss = 0
|
||||||
correct_sum = 0
|
correct_sum = 0
|
||||||
total_sum = 0
|
total_sum = 0
|
||||||
|
num_steps = len(test_dataloader)
|
||||||
|
data_iter = iter(test_dataloader)
|
||||||
|
|
||||||
for i in range(engine.schedule.num_steps):
|
for i in range(num_steps):
|
||||||
output, label, loss = engine.step()
|
output, label, loss = engine.step(data_iter)
|
||||||
accumulated_loss += loss.detach().cpu().numpy()
|
accumulated_loss += loss.detach().cpu().numpy()
|
||||||
|
|
||||||
output = _gather(
|
output = _gather(
|
||||||
|
@ -41,18 +44,21 @@ def eval(engine):
|
||||||
correct = torch.sum(label[0] == output)
|
correct = torch.sum(label[0] == output)
|
||||||
correct_sum += correct
|
correct_sum += correct
|
||||||
total_sum += label[0].size(0)
|
total_sum += label[0].size(0)
|
||||||
avg_loss = accumulated_loss / engine.schedule.num_steps
|
avg_loss = accumulated_loss / num_steps
|
||||||
return correct_sum, total_sum, avg_loss
|
return correct_sum, total_sum, avg_loss
|
||||||
|
|
||||||
|
|
||||||
def train(engine):
|
def train(engine, train_dataloader, lr_scheduler):
|
||||||
engine.train()
|
engine.train()
|
||||||
accumulated_loss = 0
|
accumulated_loss = 0
|
||||||
|
num_steps = len(train_dataloader)
|
||||||
|
data_iter = iter(train_dataloader)
|
||||||
|
|
||||||
for i in range(engine.schedule.num_steps):
|
for i in range(num_steps):
|
||||||
output, label, loss = engine.step()
|
output, label, loss = engine.step(data_iter)
|
||||||
accumulated_loss += loss.detach().cpu().numpy()
|
accumulated_loss += loss.detach().cpu().numpy()
|
||||||
avg_loss = accumulated_loss / engine.schedule.num_steps
|
avg_loss = accumulated_loss / num_steps
|
||||||
|
lr_scheduler.step()
|
||||||
return avg_loss
|
return avg_loss
|
||||||
|
|
||||||
|
|
||||||
|
@ -60,29 +66,21 @@ def train(engine):
|
||||||
@pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus")
|
@pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus")
|
||||||
def test_2p5d_parallel_vision_transformer():
|
def test_2p5d_parallel_vision_transformer():
|
||||||
# init dist
|
# init dist
|
||||||
model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize(
|
engine, train_dataloader, test_dataloader = colossalai.initialize(CONFIG_PATH)
|
||||||
CONFIG_PATH)
|
lr_scheduler = build_lr_scheduler(gpc.config.lr_scheduler, engine.optimizer)
|
||||||
logger = get_global_dist_logger()
|
logger = get_global_dist_logger()
|
||||||
|
|
||||||
engine = Engine(model=model,
|
|
||||||
train_dataloader=train_dataloader,
|
|
||||||
test_dataloader=test_dataloader,
|
|
||||||
criterion=criterion,
|
|
||||||
optimizer=optimizer,
|
|
||||||
lr_scheduler=lr_scheduler,
|
|
||||||
schedule=schedule)
|
|
||||||
|
|
||||||
logger.info('start training')
|
logger.info('start training')
|
||||||
for epoch in range(gpc.config.num_epochs):
|
for epoch in range(gpc.config.num_epochs):
|
||||||
train_loss = train(engine)
|
train_loss = train(engine, train_dataloader, lr_scheduler)
|
||||||
logger.info(f'epoch {epoch} - train loss: {train_loss}')
|
logger.info(f'epoch {epoch} - train loss: {train_loss}')
|
||||||
|
|
||||||
if epoch % 2 == 0:
|
if epoch % 2 == 0:
|
||||||
correct_sum, total_sum, eval_loss = eval(engine)
|
correct_sum, total_sum, eval_loss = eval(engine, test_dataloader)
|
||||||
logger.info(
|
logger.info(
|
||||||
f'epoch {epoch} - eval loss: {eval_loss}, total: {total_sum}, '
|
f'epoch {epoch} - eval loss: {eval_loss}, total: {total_sum}, '
|
||||||
f'correct: {correct_sum}, acc: {correct_sum / total_sum}')
|
f'correct: {correct_sum}, acc: {correct_sum / total_sum}')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_2p5d_parallel_vision_transformer()
|
test_2p5d_parallel_vision_transformer()
|
||||||
|
|
|
@ -1,16 +1,14 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
# -*- encoding: utf-8 -*-
|
# -*- encoding: utf-8 -*-
|
||||||
|
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from colossalai import initialize
|
import colossalai
|
||||||
from colossalai.context import ParallelMode
|
from colossalai.context import ParallelMode
|
||||||
from colossalai.core import global_context as gpc
|
from colossalai.core import global_context as gpc
|
||||||
from colossalai.engine import Engine
|
|
||||||
from colossalai.logging import get_global_dist_logger
|
from colossalai.logging import get_global_dist_logger
|
||||||
from colossalai.trainer import Trainer
|
from colossalai.trainer import Trainer
|
||||||
from colossalai.trainer.metric import Accuracy3D
|
from colossalai.trainer.metric import Accuracy3D
|
||||||
|
@ -29,7 +27,7 @@ def _train_epoch(epoch, engine):
|
||||||
num_samples = 0
|
num_samples = 0
|
||||||
now = time.time()
|
now = time.time()
|
||||||
epoch_start = now
|
epoch_start = now
|
||||||
progress = range(engine.schedule.num_steps)
|
progress = range(engine._schedule.num_steps)
|
||||||
if gpc.get_global_rank() == 0:
|
if gpc.get_global_rank() == 0:
|
||||||
progress = tqdm(progress, desc='[Epoch %d]' % epoch, miniters=1)
|
progress = tqdm(progress, desc='[Epoch %d]' % epoch, miniters=1)
|
||||||
for step in progress:
|
for step in progress:
|
||||||
|
@ -68,7 +66,7 @@ def _eval(epoch, engine):
|
||||||
ParallelMode.PARALLEL_3D_WEIGHT)
|
ParallelMode.PARALLEL_3D_WEIGHT)
|
||||||
total = 0
|
total = 0
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for _ in range(engine.schedule.num_steps):
|
for _ in range(engine._schedule.num_steps):
|
||||||
outputs, targets, loss = engine.step()
|
outputs, targets, loss = engine.step()
|
||||||
if isinstance(outputs, (list, tuple)):
|
if isinstance(outputs, (list, tuple)):
|
||||||
outputs = outputs[0]
|
outputs = outputs[0]
|
||||||
|
@ -80,32 +78,25 @@ def _eval(epoch, engine):
|
||||||
|
|
||||||
print_rank_0(
|
print_rank_0(
|
||||||
'[Epoch %d] Evaluation loss: %.3f | Acc: %.3f%%' %
|
'[Epoch %d] Evaluation loss: %.3f | Acc: %.3f%%' %
|
||||||
(epoch, eval_loss / engine.schedule.num_steps,
|
(epoch, eval_loss / engine._schedule.num_steps,
|
||||||
acc.get_accumulated_value() * 100), logger)
|
acc.get_accumulated_value() * 100), logger)
|
||||||
|
|
||||||
|
|
||||||
def train():
|
def train():
|
||||||
model, train_dataloader, test_dataloader, criterion, \
|
# init dist
|
||||||
optimizer, schedule, lr_scheduler = initialize(CONFIG_PATH)
|
engine, train_dataloader, test_dataloader = colossalai.initialize(CONFIG_PATH)
|
||||||
|
|
||||||
logger = get_global_dist_logger()
|
logger = get_global_dist_logger()
|
||||||
|
|
||||||
engine = Engine(model=model,
|
|
||||||
train_dataloader=train_dataloader,
|
|
||||||
test_dataloader=test_dataloader,
|
|
||||||
criterion=criterion,
|
|
||||||
optimizer=optimizer,
|
|
||||||
lr_scheduler=lr_scheduler,
|
|
||||||
schedule=schedule)
|
|
||||||
logger.info("Engine is built", ranks=[0])
|
logger.info("Engine is built", ranks=[0])
|
||||||
|
|
||||||
trainer = Trainer(engine=engine, hooks_cfg=gpc.config.hooks, verbose=True)
|
trainer = Trainer(engine=engine, verbose=True)
|
||||||
logger.info("Trainer is built", ranks=[0])
|
logger.info("Trainer is built", ranks=[0])
|
||||||
|
|
||||||
logger.info("Train start", ranks=[0])
|
logger.info("Train start", ranks=[0])
|
||||||
trainer.fit(train_dataloader=train_dataloader,
|
trainer.fit(train_dataloader=train_dataloader,
|
||||||
test_dataloader=test_dataloader,
|
test_dataloader=test_dataloader,
|
||||||
max_epochs=gpc.config.num_epochs,
|
epochs=gpc.config.num_epochs,
|
||||||
|
hooks_cfg=gpc.config.hooks,
|
||||||
display_progress=True,
|
display_progress=True,
|
||||||
test_interval=1)
|
test_interval=1)
|
||||||
|
|
||||||
|
|
|
@ -3,6 +3,7 @@ from pathlib import Path
|
||||||
|
|
||||||
BATCH_SIZE = 128
|
BATCH_SIZE = 128
|
||||||
IMG_SIZE = 32
|
IMG_SIZE = 32
|
||||||
|
num_epochs = 200
|
||||||
|
|
||||||
# resnet 50
|
# resnet 50
|
||||||
model = dict(
|
model = dict(
|
||||||
|
@ -77,18 +78,14 @@ hooks = [
|
||||||
dict(type='AccuracyHook'),
|
dict(type='AccuracyHook'),
|
||||||
dict(type='LossHook'),
|
dict(type='LossHook'),
|
||||||
dict(type='TensorboardHook', log_dir='./tfb_logs'),
|
dict(type='TensorboardHook', log_dir='./tfb_logs'),
|
||||||
|
dict(
|
||||||
|
type='LRSchedulerHook',
|
||||||
|
by_epoch=True,
|
||||||
|
lr_scheduler_cfg=dict(
|
||||||
|
type='CosineAnnealingLR',
|
||||||
|
warmup_steps=5
|
||||||
|
)
|
||||||
|
),
|
||||||
dict(type='SaveCheckpointHook', interval=5, checkpoint_dir='./ckpt'),
|
dict(type='SaveCheckpointHook', interval=5, checkpoint_dir='./ckpt'),
|
||||||
# dict(type='LoadCheckpointHook', epoch=20, checkpoint_dir='./ckpt')
|
|
||||||
]
|
]
|
||||||
|
|
||||||
# fp16 = dict(
|
|
||||||
# mode=AMP_TYPE.PARALLEL,
|
|
||||||
# initial_scale=1
|
|
||||||
# )
|
|
||||||
|
|
||||||
lr_scheduler = dict(
|
|
||||||
type='CosineAnnealingLR',
|
|
||||||
T_max=200
|
|
||||||
)
|
|
||||||
|
|
||||||
num_epochs = 200
|
|
||||||
|
|
|
@ -11,6 +11,7 @@ NUM_ATTENTION_HEADS = 8
|
||||||
SUMMA_DIM = 2
|
SUMMA_DIM = 2
|
||||||
NUM_CLASSES = 10
|
NUM_CLASSES = 10
|
||||||
DEPTH = 6
|
DEPTH = 6
|
||||||
|
num_epochs = 60
|
||||||
|
|
||||||
train_data = dict(
|
train_data = dict(
|
||||||
dataset=dict(type='CIFAR10Dataset',
|
dataset=dict(type='CIFAR10Dataset',
|
||||||
|
@ -52,13 +53,6 @@ optimizer = dict(type='Adam', lr=0.001, weight_decay=0)
|
||||||
|
|
||||||
loss = dict(type='CrossEntropyLoss2D', )
|
loss = dict(type='CrossEntropyLoss2D', )
|
||||||
|
|
||||||
# model = dict(
|
|
||||||
# type='VanillaResNet',
|
|
||||||
# block_type='ResNetBasicBlock',
|
|
||||||
# layers=[2, 2, 2, 2],
|
|
||||||
# num_cls=10
|
|
||||||
# )
|
|
||||||
|
|
||||||
model = dict(
|
model = dict(
|
||||||
type='VisionTransformerFromConfig',
|
type='VisionTransformerFromConfig',
|
||||||
tensor_splitting_cfg=dict(type='ViTInputSplitter2D', ),
|
tensor_splitting_cfg=dict(type='ViTInputSplitter2D', ),
|
||||||
|
@ -114,8 +108,15 @@ hooks = [
|
||||||
dict(type='Accuracy2DHook'),
|
dict(type='Accuracy2DHook'),
|
||||||
dict(type='LossHook'),
|
dict(type='LossHook'),
|
||||||
dict(type='TensorboardHook', log_dir='./tfb_logs'),
|
dict(type='TensorboardHook', log_dir='./tfb_logs'),
|
||||||
|
dict(
|
||||||
|
type='LRSchedulerHook',
|
||||||
|
by_epoch=True,
|
||||||
|
lr_scheduler_cfg=dict(
|
||||||
|
type='LinearWarmupLR',
|
||||||
|
warmup_steps=5
|
||||||
|
)
|
||||||
|
),
|
||||||
dict(type='SaveCheckpointHook', interval=5, checkpoint_dir='./ckpt'),
|
dict(type='SaveCheckpointHook', interval=5, checkpoint_dir='./ckpt'),
|
||||||
# dict(type='LoadCheckpointHook', epoch=20, checkpoint_dir='./ckpt')
|
|
||||||
]
|
]
|
||||||
|
|
||||||
parallel = dict(
|
parallel = dict(
|
||||||
|
@ -125,11 +126,8 @@ parallel = dict(
|
||||||
|
|
||||||
fp16 = dict(mode=AMP_TYPE.PARALLEL, initial_scale=2 ** 8)
|
fp16 = dict(mode=AMP_TYPE.PARALLEL, initial_scale=2 ** 8)
|
||||||
|
|
||||||
lr_scheduler = dict(type='LinearWarmupLR', warmup_epochs=5)
|
engine = dict(
|
||||||
|
schedule=dict(num_microbatches=1)
|
||||||
schedule = dict(num_microbatches=1)
|
)
|
||||||
|
|
||||||
num_epochs = 60
|
|
||||||
num_microbatches = 1
|
|
||||||
|
|
||||||
logging = dict(root_path='./logs')
|
logging = dict(root_path='./logs')
|
||||||
|
|
|
@ -1,25 +1,16 @@
|
||||||
import colossalai
|
import colossalai
|
||||||
from colossalai.core import global_context as gpc
|
from colossalai.core import global_context as gpc
|
||||||
from colossalai.engine import Engine
|
|
||||||
from colossalai.logging import get_global_dist_logger
|
from colossalai.logging import get_global_dist_logger
|
||||||
from colossalai.trainer import Trainer
|
from colossalai.trainer import Trainer
|
||||||
|
|
||||||
|
|
||||||
def test_trainer():
|
def test_trainer():
|
||||||
model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize()
|
engine, train_dataloader, test_dataloader = colossalai.initialize()
|
||||||
logger = get_global_dist_logger()
|
logger = get_global_dist_logger()
|
||||||
|
|
||||||
engine = Engine(
|
|
||||||
model=model,
|
|
||||||
criterion=criterion,
|
|
||||||
optimizer=optimizer,
|
|
||||||
lr_scheduler=lr_scheduler,
|
|
||||||
schedule=schedule
|
|
||||||
)
|
|
||||||
logger.info("engine is built", ranks=[0])
|
logger.info("engine is built", ranks=[0])
|
||||||
|
|
||||||
trainer = Trainer(engine=engine,
|
trainer = Trainer(engine=engine,
|
||||||
hooks_cfg=gpc.config.hooks,
|
|
||||||
verbose=True)
|
verbose=True)
|
||||||
logger.info("trainer is built", ranks=[0])
|
logger.info("trainer is built", ranks=[0])
|
||||||
|
|
||||||
|
@ -27,7 +18,8 @@ def test_trainer():
|
||||||
trainer.fit(
|
trainer.fit(
|
||||||
train_dataloader=train_dataloader,
|
train_dataloader=train_dataloader,
|
||||||
test_dataloader=test_dataloader,
|
test_dataloader=test_dataloader,
|
||||||
max_epochs=gpc.config.num_epochs,
|
hooks_cfg=gpc.config.hooks,
|
||||||
|
epochs=gpc.config.num_epochs,
|
||||||
display_progress=False,
|
display_progress=False,
|
||||||
test_interval=5
|
test_interval=5
|
||||||
)
|
)
|
||||||
|
|
|
@ -18,14 +18,16 @@ level = os.environ['LEVEL']
|
||||||
CONFIG_PATH = Path(__file__).parent.parent.joinpath(f'configs/vit_2d_zero{level}.py')
|
CONFIG_PATH = Path(__file__).parent.parent.joinpath(f'configs/vit_2d_zero{level}.py')
|
||||||
|
|
||||||
|
|
||||||
def eval(engine):
|
def eval_epoch(engine: Engine, test_dataloader):
|
||||||
engine.eval()
|
engine.eval()
|
||||||
accumulated_loss = 0
|
accumulated_loss = 0
|
||||||
correct_sum = 0
|
correct_sum = 0
|
||||||
total_sum = 0
|
total_sum = 0
|
||||||
|
num_steps = len(test_dataloader)
|
||||||
|
data_iter = iter(test_dataloader)
|
||||||
|
|
||||||
for i in range(engine.schedule.num_steps):
|
for i in range(num_steps):
|
||||||
output, label, loss = engine.step()
|
output, label, loss = engine.step(data_iter)
|
||||||
accumulated_loss += loss.detach().cpu().numpy()
|
accumulated_loss += loss.detach().cpu().numpy()
|
||||||
|
|
||||||
output = _gather(
|
output = _gather(
|
||||||
|
@ -42,18 +44,19 @@ def eval(engine):
|
||||||
correct = torch.sum(label[0] == output)
|
correct = torch.sum(label[0] == output)
|
||||||
correct_sum += correct
|
correct_sum += correct
|
||||||
total_sum += label[0].size(0)
|
total_sum += label[0].size(0)
|
||||||
avg_loss = accumulated_loss / engine.schedule.num_steps
|
avg_loss = accumulated_loss / num_steps
|
||||||
return correct_sum, total_sum, avg_loss
|
return correct_sum, total_sum, avg_loss
|
||||||
|
|
||||||
|
|
||||||
def train(engine):
|
def train_epoch(engine, train_dataloader):
|
||||||
engine.train()
|
engine.train()
|
||||||
accumulated_loss = 0
|
accumulated_loss = 0
|
||||||
|
num_steps = len(train_dataloader)
|
||||||
for i in range(engine.schedule.num_steps):
|
data_iter = iter(train_dataloader)
|
||||||
output, label, loss = engine.step()
|
for i in range(num_steps):
|
||||||
|
output, label, loss = engine.step(data_iter)
|
||||||
accumulated_loss += loss.detach().cpu().numpy()
|
accumulated_loss += loss.detach().cpu().numpy()
|
||||||
avg_loss = accumulated_loss / engine.schedule.num_steps
|
avg_loss = accumulated_loss / num_steps
|
||||||
return avg_loss
|
return avg_loss
|
||||||
|
|
||||||
|
|
||||||
|
@ -61,30 +64,17 @@ def train(engine):
|
||||||
@pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus")
|
@pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus")
|
||||||
def test_2d_parallel_vision_transformer():
|
def test_2d_parallel_vision_transformer():
|
||||||
# init dist
|
# init dist
|
||||||
model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize(
|
engine, train_dataloader, test_dataloader = colossalai.initialize(CONFIG_PATH)
|
||||||
CONFIG_PATH)
|
|
||||||
logger = get_global_dist_logger()
|
logger = get_global_dist_logger()
|
||||||
|
|
||||||
engine = Engine(model=model,
|
|
||||||
train_dataloader=train_dataloader,
|
|
||||||
test_dataloader=test_dataloader,
|
|
||||||
criterion=criterion,
|
|
||||||
optimizer=optimizer,
|
|
||||||
lr_scheduler=lr_scheduler,
|
|
||||||
schedule=schedule)
|
|
||||||
|
|
||||||
# for param in model.parameters():
|
|
||||||
# if isinstance(param, torch.HalfTensor):
|
|
||||||
# print(param.shape)
|
|
||||||
|
|
||||||
logger.info('start training')
|
logger.info('start training')
|
||||||
for epoch in range(gpc.config.num_epochs):
|
for epoch in range(gpc.config.num_epochs):
|
||||||
train_loss = train(engine)
|
train_loss = train_epoch(engine, train_dataloader)
|
||||||
|
|
||||||
logger.info(f'epoch {epoch} - train loss: {train_loss}')
|
logger.info(f'epoch {epoch} - train loss: {train_loss}')
|
||||||
|
|
||||||
if epoch % 2 == 0:
|
if epoch % 2 == 0:
|
||||||
correct_sum, total_sum, eval_loss = eval(engine)
|
correct_sum, total_sum, eval_loss = eval_epoch(engine, test_dataloader)
|
||||||
logger.info(
|
logger.info(
|
||||||
f'epoch {epoch} - eval loss: {eval_loss}, total: {total_sum}, '
|
f'epoch {epoch} - eval loss: {eval_loss}, total: {total_sum}, '
|
||||||
f'correct: {correct_sum}, acc: {correct_sum / total_sum}')
|
f'correct: {correct_sum}, acc: {correct_sum / total_sum}')
|
||||||
|
|
Loading…
Reference in New Issue