mirror of https://github.com/hpcaitech/ColossalAI
Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b7699.

* improved consistency between trainer, engine and schedule (#23)

Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
branch: pull/28/head
parent: 2b05de4c64
commit: 3defa32aee
README.md
@@ -42,26 +42,18 @@ pip install -v --no-cache-dir --global-option="--cuda_ext" .
 ```python
 import colossalai
-from colossalai.engine import Engine
 from colossalai.trainer import Trainer
 from colossalai.core import global_context as gpc
 
-model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize()
-engine = Engine(
-    model=model,
-    criterion=criterion,
-    optimizer=optimizer,
-    lr_scheduler=lr_scheduler,
-    schedule=schedule
-)
-
+engine, train_dataloader, test_dataloader = colossalai.initialize()
 trainer = Trainer(engine=engine,
-                  hooks_cfg=gpc.config.hooks,
                   verbose=True)
 trainer.fit(
     train_dataloader=train_dataloader,
     test_dataloader=test_dataloader,
-    max_epochs=gpc.config.num_epochs,
+    epochs=gpc.config.num_epochs,
+    hooks_cfg=gpc.config.hooks,
     display_progress=True,
     test_interval=5
 )
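With this change, `colossalai.initialize()` builds the `Engine` internally, and the epoch count and hook configuration move to `trainer.fit()`. A minimal sketch of the config values the snippet above reads; the hook type name is only illustrative and not taken from this commit:

```python
# config.py -- hypothetical excerpt; only num_epochs and hooks are referenced above
num_epochs = 60
hooks = [
    dict(type='LogMetricByEpochHook'),  # illustrative hook name, check the hooks registry
]
```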
@@ -1,2 +1,10 @@
-from .builder import *
+from .builder import (build_schedule, build_lr_scheduler, build_model, build_optimizer, build_optimizer_wrapper,
+                      build_layer, build_loss, build_hooks, build_dataset, build_transform, build_data_sampler,
+                      build_gradient_handler)
 from .pipeline import ModelInitializer
+
+__all__ = [
+    'build_schedule', 'build_lr_scheduler', 'build_model', 'build_optimizer', 'build_optimizer_wrapper',
+    'build_layer', 'build_loss', 'build_hooks', 'build_dataset', 'build_transform', 'build_data_sampler',
+    'build_gradient_handler', 'ModelInitializer'
+]
@@ -181,18 +181,6 @@ def build_transform(config):
     return build_from_registry(config, TRANSFORMS)
 
 
-def build_pipe_alloc_policy(config):
-    """Returns a pipeline allocation policy object constructed from `config`.
-
-    :param config: A python dict or a :class:`colossalai.context.Config` object
-        containing information used in the construction of the return object
-    :type config: dict or :class:`colossalai.context.Config`
-    :return: A pipeline allocation policy object
-    :rtype:
-    """
-    return build_from_registry(config, PIPE_ALLOC_POLICY)
-
-
 def build_data_sampler(config, dataset):
     """Returns a data sampler object of :class:`colossalai.nn.data.sampler.BaseSampler`
     constructed from `config`.
@@ -235,7 +223,7 @@ def build_optimizer_wrapper(config, optimizer, model=None):
     return OPTIMIZER_WRAPPERS.get_module(mod_type)(optimizer, **config_)
 
 
-def build_lr_scheduler(config, optimizer, total_steps, num_steps_per_epoch):
+def build_lr_scheduler(config, optimizer):
     """Returns a learning rate scheduler object of :class:`torch.optim.lr_scheduler`
     constructed from `config`, `optimizer`, `total_steps` and `num_steps_per_epoch`.
 
@@ -254,9 +242,16 @@ def build_lr_scheduler(config, optimizer, total_steps, num_steps_per_epoch):
     """
     config_ = config.copy()
     mod_type = config_.pop('type')
-    # warmup epochs will overwrite warmup steps
-    if 'warmup_epochs' in config_:
-        warmup_epochs = config_.pop('warmup_epochs')
-        config_['warmup_steps'] = int(num_steps_per_epoch * warmup_epochs)
-    return LR_SCHEDULERS.get_module(mod_type)(optimizer, total_steps, num_steps_per_epoch=num_steps_per_epoch,
-                                              **config_)
+    return LR_SCHEDULERS.get_module(mod_type)(optimizer, **config_)
+
+
+def build_schedule(config):
+    """Returns a schedule of :class:`colossalai.engine.schedule.BaseSchedule`.
+
+    :param config: A python dict or a :class:`colossalai.context.Config` object
+        containing information used in the construction of the return object
+    :type config: dict or :class:`colossalai.context.Config`
+    :return: An object of :class:`colossalai.engine.schedule.BaseSchedule`
+    :rtype: :class:`colossalai.engine.schedule.BaseSchedule`
+    """
+    return build_from_registry(config, SCHEDULE)
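After this change `build_lr_scheduler` no longer derives warmup steps itself: the config dict only needs the registered `type` plus that class's own keyword arguments. A sketch of the call pattern, demonstrated against a stock torch scheduler since the registry lookup reduces to the same call; the `CosineAnnealingLR` type name is only illustrative:

```python
import torch

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

config = dict(type='CosineAnnealingLR', T_max=100)
cfg = config.copy()
mod_type = cfg.pop('type')  # the 'type' key selects the registered class
# build_lr_scheduler(config, optimizer) now boils down to
# LR_SCHEDULERS.get_module(mod_type)(optimizer, **cfg); same shape as:
scheduler = getattr(torch.optim.lr_scheduler, mod_type)(optimizer, **cfg)
```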
@@ -1,7 +1,7 @@
-from .amp_type import AMP_TYPE
 from ._base_engine import Engine
 from .gradient_handler import *
 from .schedule import *
+from .amp import *
 
 
 __all__ = ['Engine']
@@ -1,7 +1,9 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-from typing import Optional
+from torch.nn import Module
+from torch.nn.modules.loss import _Loss
+from torch.optim import Optimizer
 
 from colossalai.builder import build_gradient_handler
 from colossalai.context import ParallelMode
@@ -9,89 +11,103 @@ from colossalai.core import global_context as gpc
 from colossalai.logging import get_global_dist_logger
 from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
                            ZeroRedundancyOptimizer_Level_3)
-from torch.nn import Module
-from torch.nn.modules.loss import _Loss
-from torch.optim import Optimizer
-from torch.optim.lr_scheduler import _LRScheduler
-from torch.utils.data import DataLoader
-
-from .schedule import BaseSchedule, NoPipelineSchedule
+from .schedule import BaseSchedule
 
 
 class Engine:
     """Basic engine class for training and evaluation. It runs a specific process method
     :meth:`step` which is based on the given :attr:`schedule` over each batch of a dataset.
+    It controls a iteration in training.
 
-    :param train_dataloader: Dataloader in training
-    :param test_dataloader: Dataloader in evaluation
     :param model: The neural network model
-    :param criterion: Criterion for calculating loss
     :param optimizer: Optimizer for updating the parameters
-    :param lr_scheduler: Learning rate scheduler ajusting learning rate during the training or evaluation
-    :param schedule: Running schedule in :meth:`step`
-    :type train_dataloader: DataLoader, optional
-    :type test_dataloader: DataLoader, optional
+    :param step_schedule: Running schedule in :meth:`step`
+    :param gradient_accumulation: Steps of gradient accumulation
+    :param gradient_clipping: The norm of gradient clipping
     :type model: Module
-    :type criterion: _Loss, optional
-    :type optimizer: Optimizer, optional
-    :type lr_scheduler: _LRScheduler, optional
-    :type schedule: BaseSchedule, optional
+    :type optimizer: Optimizer
+    :type step_schedule: BaseSchedule, optional
+    :type gradient_accumulation: int, optional
+    :type gradient_clipping: float, optional
     """
 
     def __init__(self,
-                 train_dataloader: Optional[DataLoader] = None,
-                 test_dataloader: Optional[DataLoader] = None,
-                 model: Module = None,
-                 criterion: _Loss = None,
-                 optimizer: Optimizer = None,
-                 lr_scheduler: Optional[_LRScheduler] = None,
-                 schedule: BaseSchedule = None):
-        self.train_dataloader = train_dataloader
-        self.test_dataloader = test_dataloader
-        assert model is not None, "Engine requires a model"
-        self.model = model
-        self.criterion = criterion
-        self.optimizer = optimizer
-        self.lr_scheduler = lr_scheduler
-        self.schedule = schedule if schedule is not None \
-            else NoPipelineSchedule()
+                 model: Module,
+                 optimizer: Optimizer,
+                 criterion: _Loss,
+                 step_schedule: BaseSchedule,
+                 gradient_handlers: list = None,
+                 gradient_accumulation: int = 1,
+                 gradient_clipping: float = 0.0,
+                 ):
+        self._model = model
+        self._optimizer = optimizer
+        self._criterion = criterion
+        self._schedule = step_schedule
+
+        # schedule initialize
+        self._schedule.initialize(model, optimizer)
+
+        # state
+        self.training = True  # default
+
+        # gradient accumulation
+        assert gradient_accumulation > 0, 'gradient accumulation size must be larger than 0'
+        self._grad_accum_size = gradient_accumulation
+        self._grad_clip = gradient_clipping
         self._logger = get_global_dist_logger()
 
         # build gradient handler
         self._gradient_handlers = []
-        gradient_handler_cfg = []
 
-        if hasattr(gpc.config, 'gradient_handler'):
-            assert isinstance(gpc.config.gradient_handler, list), \
+        if gradient_handlers is not None:
+            assert isinstance(gradient_handlers, list), \
                 f'argument gradient_handler_cfg expected type list, ' \
-                f'but got type {type(gpc.config.gradient_handler)}'
-            gradient_handler_cfg = gpc.config.gradient_handler
-        elif isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2,
-                                         ZeroRedundancyOptimizer_Level_3)):
-            gradient_handler_cfg = [dict(type='ZeROGradientHandler')]
+                f'but got type {type(gradient_handlers)}'
+        elif isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2,
+                                    ZeroRedundancyOptimizer_Level_3)):
+            gradient_handlers = [dict(type='ZeROGradientHandler')]
             self._logger.info(
                 "Training with zero is detected, ZeROGradientHandler is automatically "
                 "added even though not specified in the configuration",
                 ranks=[0])
         elif gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(
                 ParallelMode.DATA) > 1:
-            gradient_handler_cfg = [dict(type='DataParallelGradientHandler')]
+            gradient_handlers = [dict(type='DataParallelGradientHandler')]
             self._logger.info(
                 "Data parallel training is detected, DataParallelGradientHandler is automatically "
                 "added even though not specified in the configuration",
                 ranks=[0])
-        if len(gradient_handler_cfg) == 0:
+
+        if gradient_handlers is None:
             self._logger.warning(
                 "No gradient handler is set up, please make sure you do not need "
                 "to all-reduce the gradients after a training step.",
                 ranks=[0])
-        for cfg in gradient_handler_cfg:
-            handler = build_gradient_handler(cfg, self.model, self.optimizer)
-            self._gradient_handlers.append(handler)
+        else:
+            for cfg in gradient_handlers:
+                handler = build_gradient_handler(cfg, model, optimizer)
+                self._gradient_handlers.append(handler)
 
-        self.schedule.initialize(self.train_dataloader, self.model,
-                                 self.criterion, self.optimizer,
-                                 self.lr_scheduler)
-        self.forward_only = False
+    @property
+    def model(self):
+        return self._model
+
+    @property
+    def optimizer(self):
+        return self._optimizer
+
+    @property
+    def criterion(self):
+        return self._criterion
+
+    @property
+    def schedule(self):
+        return self._schedule
+
+    @property
+    def gradient_accumulation(self):
+        return self._grad_accum_size
 
     def handle_gradient(self):
         """Handles all-reduce operations of gradients across different parallel groups.
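A minimal sketch of building the reworked Engine by hand. Normally `colossalai.initialize()` constructs it, and a launched ColossalAI context (logger, parallel groups) is assumed; the import path of `NoPipelineSchedule` is inferred from the schedule package in this commit:

```python
import torch
from colossalai.engine import Engine
from colossalai.engine.schedule import NoPipelineSchedule

model = torch.nn.Linear(32, 10).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = torch.nn.CrossEntropyLoss()

engine = Engine(model=model,
                optimizer=optimizer,
                criterion=criterion,
                step_schedule=NoPipelineSchedule(),
                gradient_accumulation=4,  # 4 micro-batches per optimizer step
                gradient_clipping=1.0)    # clip gradient norm at 1.0
```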
@@ -99,72 +115,62 @@ class Engine:
         for handler in self._gradient_handlers:
             handler.handle_gradient()
 
-    def set_dataloader(self, data: DataLoader, train: bool = True):
-        """Sets dataloader in training or evaluation.
-
-        :param data: Dataloader to be set
-        :param train: Set training dataloader if True, otherwise evaluation dataloader
-        :type data: DataLoader
-        :type train: bool
-        """
-        if train:
-            self.train_dataloader = data
-        else:
-            self.test_dataloader = data
-
-    def get_model(self):
-        """Returns the neural network model in the engine.
-        """
-        return self.model
-
-    def get_optimizer(self):
-        """Returns optimizier in the engine.
-        """
-        return self.optimizer
-
-    def get_lr_scheduler(self):
-        """Returns the learning rate scheduler in the engine.
-        """
-        return self.lr_scheduler
-
     def train(self):
         """Sets the model to training mode.
         """
-        self.forward_only = False
-        self.schedule.train(dataloader=self.train_dataloader, mode=True)
+        self.training = True
+        self._model.train()
 
     def eval(self):
         """Sets the model to evaluation mode.
         """
-        self.forward_only = True
-        self.schedule.train(dataloader=self.test_dataloader, mode=False)
+        self.training = False
+        self._model.eval()
 
-    def is_train(self):
-        """Returns True if it is in training, otherwise False.
-        """
-        return not self.forward_only
-
-    def get_lr(self):
-        """Gets current learning rate.
-        """
-        return self.schedule.get_lr()
-
-    def step(self, return_loss=True):
+    def step(self,
+             data_iter,
+             is_last_iteration: bool = False,
+             return_loss=True):
         """A running step based on the schedule. Usually, it runs a training or
         evaluation over a batch of dataset.
 
+        :param data_iter: Data iterator of the dataset
+        :param is_last_iteration: If True, this iteration is the last iteration in the epoch
         :param return_loss: loss will be returned if True
-        :type return_loss: bool
+        :type data_iter: Iterator
+        :type is_last_iteration: bool, optional
+        :type return_loss: bool, optional
         :return: (output, lablel, loss)
         """
-        self.schedule.zero_grad(forward_only=self.forward_only)
-
-        output, label, loss = self.schedule.forward_backward_step(
-            forward_only=self.forward_only, return_loss=return_loss)
-
-        if not self.forward_only:
-            # all reduce gradients
-            self.handle_gradient()
-
-        self.schedule.step()
+        if self.training:
+            self._optimizer.zero_grad()
+
+        # differentiate training and eval with grad accum
+        if self.training:
+            for i in range(self._grad_accum_size):
+                output, label, loss = self._schedule.forward_backward_step(
+                    data_iter, self._model, self._criterion, self._optimizer,
+                    forward_only=False,
+                    grad_accum_size=self._grad_accum_size,
+                    return_loss=return_loss)
+
+                if i == self._grad_accum_size - 1:
+                    # all reduce gradients
+                    self.handle_gradient()
+                    self._schedule.optimizer_step(self._model, self._optimizer, self._grad_clip)
+        else:
+            output, label, loss = self._schedule.forward_backward_step(
+                data_iter, self._model, self._criterion, self._optimizer,
+                forward_only=True,
+                grad_accum_size=1,
+                return_loss=return_loss)
+
+        # consume the remaining dataset left out due to gradient accumulation
+        if is_last_iteration:
+            while True:
+                try:
+                    _ = next(data_iter)
+                except StopIteration:
+                    break
 
         return output, label, loss
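A sketch of driving the new `Engine.step()` directly, which the Trainer normally does internally. It assumes `train_dataloader` exists and the engine was built as in the earlier sketch; note that one `step()` now consumes `gradient_accumulation` batches from the iterator:

```python
engine.train()
data_iter = iter(train_dataloader)
steps_per_epoch = len(train_dataloader) // engine.gradient_accumulation

for i in range(steps_per_epoch):
    # the last step also drains leftover batches skipped by gradient accumulation
    output, label, loss = engine.step(data_iter,
                                      is_last_iteration=(i == steps_per_epoch - 1))
```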
@@ -0,0 +1,2 @@
+from .grad_scaler import GradScaler
+from .amp_type import AMP_TYPE
@@ -0,0 +1,577 @@
+# modified from https://github.com/pytorch/pytorch/blob/master/torch/cuda/amp/grad_scaler.p
+import torch
+from collections import defaultdict, abc
+import warnings
+from enum import Enum
+from typing import Any, Dict, List, Optional, Tuple
+from colossalai.context import ParallelMode
+import torch.distributed as dist
+from colossalai.core import global_context as gpc
+
+
+class _MultiDeviceReplicator(object):
+    """
+    Lazily serves copies of a tensor to requested devices. Copies are cached per-device.
+    """
+
+    def __init__(self, master_tensor: torch.Tensor) -> None:
+        assert master_tensor.is_cuda or master_tensor.device.type == 'xla'
+        self.master = master_tensor
+        self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}
+
+    def get(self, device) -> torch.Tensor:
+        retval = self._per_device_tensors.get(device, None)
+        if retval is None:
+            retval = self.master.to(
+                device=device, non_blocking=True, copy=True)
+            self._per_device_tensors[device] = retval
+        return retval
+
+
+# Defines default_factory for GradScaler's _per_optimizer_states defaultdict,
+# as well as associated "enum" values.  Prefers defining these at top level because
+# - Lambdas can't be pickled, so we don't want to supply a lambda as the factory.
+# - Defining READY, UNSCALED, STEPPED and _refresh_per_optimizer_state within GradScaler
+#   causes a circular reference, which we'd rather avoid.
+class OptState(Enum):
+    READY = 0
+    UNSCALED = 1
+    STEPPED = 2
+
+
+def _refresh_per_optimizer_state():
+    return {"stage": OptState.READY, "found_inf_per_device": {}}
+
+
+class GradScaler(object):
+    _scale: Optional[torch.Tensor]
+    _grows_tracker: Optional[torch.Tensor]
+    _per_optimizer_states: Dict[int, Dict[str, Any]]
+    """
+    An instance ``scaler`` of :class:`GradScaler` helps perform the steps of gradient scaling
+    conveniently.
+
+    * ``scaler.scale(loss)`` multiplies a given loss by ``scaler``'s current scale factor.
+    * ``scaler.step(optimizer)`` safely unscales gradients and calls ``optimizer.step()``.
+    * ``scaler.update()`` updates ``scaler``'s scale factor.
+
+    Example::
+
+        # Creates a GradScaler once at the beginning of training.
+        scaler = GradScaler()
+
+        for epoch in epochs:
+            for input, target in data:
+                optimizer.zero_grad()
+                output = model(input)
+                loss = loss_fn(output, target)
+
+                # Scales loss.  Calls backward() on scaled loss to create scaled gradients.
+                scaler.scale(loss).backward()
+
+                # scaler.step() first unscales gradients of the optimizer's params.
+                # If gradients don't contain infs/NaNs, optimizer.step() is then called,
+                # otherwise, optimizer.step() is skipped.
+                scaler.step(optimizer)
+
+                # Updates the scale for next iteration.
+                scaler.update()
+
+    See the :ref:`Automatic Mixed Precision examples<amp-examples>` for usage
+    (along with autocasting) in more complex cases like gradient clipping, gradient accumulation, gradient penalty,
+    and multiple losses/optimizers.
+
+    ``scaler`` dynamically estimates the scale factor each iteration.  To minimize gradient underflow,
+    a large scale factor should be used.  However, ``float16`` values can "overflow" (become inf or NaN) if
+    the scale factor is too large.  Therefore, the optimal scale factor is the largest factor that can be used
+    without incurring inf or NaN gradient values.
+    ``scaler`` approximates the optimal scale factor over time by checking the gradients for infs and NaNs during every
+    ``scaler.step(optimizer)`` (or optional separate ``scaler.unscale_(optimizer)``, see :meth:`unscale_`).
+
+    * If infs/NaNs are found, ``scaler.step(optimizer)`` skips the underlying ``optimizer.step()`` (so the params
+      themselves remain uncorrupted) and ``update()`` multiplies the scale by ``backoff_factor``.
+
+    * If no infs/NaNs are found, ``scaler.step(optimizer)`` runs the underlying ``optimizer.step()`` as usual.
+      If ``growth_interval`` unskipped iterations occur consecutively, ``update()`` multiplies the scale by
+      ``growth_factor``.
+
+    The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its
+    value calibrates.  ``scaler.step`` will skip the underlying ``optimizer.step()`` for these
+    iterations.  After that, step skipping should occur rarely (once every few hundred or thousand iterations).
+
+    Args:
+        init_scale (float, optional, default=2.**16):  Initial scale factor.
+        growth_factor (float, optional, default=2.0):  Factor by which the scale is multiplied during
+            :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations.
+        backoff_factor (float, optional, default=0.5):  Factor by which the scale is multiplied during
+            :meth:`update` if inf/NaN gradients occur in an iteration.
+        growth_interval (int, optional, default=2000):  Number of consecutive iterations without inf/NaN gradients
+            that must occur for the scale to be multiplied by ``growth_factor``.
+        enabled (bool, optional, default=True):  If ``False``, disables gradient scaling. :meth:`step` simply
+            invokes the underlying ``optimizer.step()``, and other methods become no-ops.
+    """
+
+    def __init__(self,
+                 init_scale=2.**16,
+                 growth_factor=2.0,
+                 backoff_factor=0.5,
+                 growth_interval=2000,
+                 enabled=True):
+        if enabled and not torch.cuda.is_available():
+            warnings.warn(
+                "torch.cuda.amp.GradScaler is enabled, but CUDA is not available.  Disabling.")
+            self._enabled = False
+        else:
+            self._enabled = enabled
+
+        if self._enabled:
+            assert growth_factor > 1.0, "The growth factor must be > 1.0."
+            assert backoff_factor < 1.0, "The backoff factor must be < 1.0."
+
+            self._init_scale = init_scale
+            # self._scale will be lazily initialized during the first call to scale()
+            self._scale = None
+            self._growth_factor = growth_factor
+            self._backoff_factor = backoff_factor
+            self._growth_interval = growth_interval
+            self._init_growth_tracker = 0
+            # self._growth_tracker will be lazily initialized during the first call to scale()
+            self._growth_tracker = None
+            self._per_optimizer_states = defaultdict(
+                _refresh_per_optimizer_state)
+
+    def _check_scale_growth_tracker(self, funcname) -> Tuple[torch.Tensor, torch.Tensor]:
+        fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration."
+        assert self._scale is not None, "Attempted {} but _scale is None.  ".format(
+            funcname) + fix
+        assert self._growth_tracker is not None, "Attempted {} but _growth_tracker is None.  ".format(
+            funcname) + fix
+        return (self._scale, self._growth_tracker)
+
+    def _lazy_init_scale_growth_tracker(self, dev):
+        assert self._growth_tracker is None, "_growth_tracker initialized before _scale"
+        self._scale = torch.full(
+            (1,), self._init_scale, dtype=torch.float32, device=dev)
+        self._growth_tracker = torch.full(
+            (1,), self._init_growth_tracker, dtype=torch.int32, device=dev)
+
+    def scale(self, outputs):
+        """
+        Multiplies ('scales') a tensor or list of tensors by the scale factor.
+
+        Returns scaled outputs.  If this instance of :class:`GradScaler` is not enabled, outputs are returned
+        unmodified.
+
+        Args:
+            outputs (Tensor or iterable of Tensors):  Outputs to scale.
+        """
+        if not self._enabled:
+            return outputs
+
+        # Short-circuit for the common case.
+        if isinstance(outputs, torch.Tensor):
+            assert outputs.is_cuda or outputs.device.type == 'xla'
+            if self._scale is None:
+                self._lazy_init_scale_growth_tracker(outputs.device)
+            assert self._scale is not None
+            return outputs * self._scale.to(device=outputs.device, non_blocking=True)
+
+        # Invoke the more complex machinery only if we're treating multiple outputs.
+        # holds a reference that can be overwritten by apply_scale
+        stash: List[_MultiDeviceReplicator] = []
+
+        def apply_scale(val):
+            if isinstance(val, torch.Tensor):
+                assert val.is_cuda or val.device.type == 'xla'
+                if len(stash) == 0:
+                    if self._scale is None:
+                        self._lazy_init_scale_growth_tracker(val.device)
+                    assert self._scale is not None
+                    stash.append(_MultiDeviceReplicator(self._scale))
+                return val * stash[0].get(val.device)
+            elif isinstance(val, abc.Iterable):
+                iterable = map(apply_scale, val)
+                if isinstance(val, list) or isinstance(val, tuple):
+                    return type(val)(iterable)
+                else:
+                    return iterable
+            else:
+                raise ValueError(
+                    "outputs must be a Tensor or an iterable of Tensors")
+
+        return apply_scale(outputs)
+
+    def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16):
+        per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
+        per_device_found_inf = _MultiDeviceReplicator(found_inf)
+
+        # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
+        # There could be hundreds of grads, so we'd like to iterate through them just once.
+        # However, we don't know their devices or dtypes in advance.
+
+        # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
+        # Google says mypy struggles with defaultdicts type annotations.
+        per_device_and_dtype_grads = defaultdict(
+            lambda: defaultdict(list))  # type: ignore[var-annotated]
+        with torch.no_grad():
+            for group in optimizer.param_groups:
+                for param in group["params"]:
+                    if param.grad is None:
+                        continue
+                    if (not allow_fp16) and param.grad.dtype == torch.float16:
+                        raise ValueError(
+                            "Attempting to unscale FP16 gradients.")
+                    if param.grad.is_sparse:
+                        # is_coalesced() == False means the sparse grad has values with duplicate indices.
+                        # coalesce() deduplicates indices and adds all values that have the same index.
+                        # For scaled fp16 values, there's a good chance coalescing will cause overflow,
+                        # so we should check the coalesced _values().
+                        if param.grad.dtype is torch.float16:
+                            param.grad = param.grad.coalesce()
+                        to_unscale = param.grad._values()
+                    else:
+                        to_unscale = param.grad
+
+                    # TODO: is there a way to split by device and dtype without appending in the inner loop?
+                    per_device_and_dtype_grads[to_unscale.device][to_unscale.dtype].append(
+                        to_unscale)
+
+            for device, per_dtype_grads in per_device_and_dtype_grads.items():
+                for grads in per_dtype_grads.values():
+                    torch._amp_foreach_non_finite_check_and_unscale_(grads,
+                                                                     per_device_found_inf.get(
+                                                                         device),
+                                                                     per_device_inv_scale.get(device))
+        # For tensor parallel paramters it should be all-reduced over tensor parallel process group
+        if gpc.is_initialized(ParallelMode.TENSOR) and gpc.get_world_size(ParallelMode.TENSOR) > 1:
+            for tensor in per_device_found_inf._per_device_tensors.values():
+                dist.all_reduce(tensor, op=dist.ReduceOp.MAX,
+                                group=gpc.get_group(ParallelMode.TENSOR))
+        return per_device_found_inf._per_device_tensors
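The all-reduce at the end of `_unscale_grads_` is the tensor-parallel twist in this otherwise vendored scaler: each TP rank unscales only its own gradient shard, so the inf/NaN flag must be combined across the TP group before any rank decides to skip `optimizer.step()`. A standalone sketch of that pattern (the function name is ours):

```python
import torch
import torch.distributed as dist

def sync_found_inf(found_inf: torch.Tensor, tp_group) -> torch.Tensor:
    # after this call, every rank sees a non-zero flag if any rank saw inf/NaN
    dist.all_reduce(found_inf, op=dist.ReduceOp.MAX, group=tp_group)
    return found_inf
```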
+    def unscale_(self, optimizer):
+        """
+        Divides ("unscales") the optimizer's gradient tensors by the scale factor.
+
+        :meth:`unscale_` is optional, serving cases where you need to
+        :ref:`modify or inspect gradients<working-with-unscaled-gradients>`
+        between the backward pass(es) and :meth:`step`.
+        If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`.
+
+        Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::
+
+            ...
+            scaler.scale(loss).backward()
+            scaler.unscale_(optimizer)
+            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
+            scaler.step(optimizer)
+            scaler.update()
+
+        Args:
+            optimizer (torch.optim.Optimizer):  Optimizer that owns the gradients to be unscaled.
+
+        .. note::
+            :meth:`unscale_` does not incur a CPU-GPU sync.
+
+        .. warning::
+            :meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
+            and only after all gradients for that optimizer's assigned parameters have been accumulated.
+            Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.
+
+        .. warning::
+            :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
+        """
+        if not self._enabled:
+            return
+
+        self._check_scale_growth_tracker("unscale_")
+
+        optimizer_state = self._per_optimizer_states[id(optimizer)]
+
+        if optimizer_state["stage"] is OptState.UNSCALED:
+            raise RuntimeError(
+                "unscale_() has already been called on this optimizer since the last update().")
+        elif optimizer_state["stage"] is OptState.STEPPED:
+            raise RuntimeError("unscale_() is being called after step().")
+
+        # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
+        assert self._scale is not None
+        inv_scale = self._scale.double().reciprocal().float()
+        found_inf = torch.full(
+            (1,), 0.0, dtype=torch.float32, device=self._scale.device)
+
+        optimizer_state["found_inf_per_device"] = self._unscale_grads_(
+            optimizer, inv_scale, found_inf, False)
+        optimizer_state["stage"] = OptState.UNSCALED
+
+    def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs):
+        retval = None
+        if not sum(v.item() for v in optimizer_state["found_inf_per_device"].values()):
+            retval = optimizer.step(*args, **kwargs)
+        return retval
+    def step(self, optimizer, *args, **kwargs):
+        """
+        :meth:`step` carries out the following two operations:
+
+        1.  Internally invokes ``unscale_(optimizer)`` (unless :meth:`unscale_` was explicitly called for ``optimizer``
+            earlier in the iteration).  As part of the :meth:`unscale_`, gradients are checked for infs/NaNs.
+        2.  If no inf/NaN gradients are found, invokes ``optimizer.step()`` using the unscaled
+            gradients.  Otherwise, ``optimizer.step()`` is skipped to avoid corrupting the params.
+
+        ``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``.
+
+        Returns the return value of ``optimizer.step(*args, **kwargs)``.
+
+        Args:
+            optimizer (torch.optim.Optimizer):  Optimizer that applies the gradients.
+            args:  Any arguments.
+            kwargs:  Any keyword arguments.
+
+        .. warning::
+            Closure use is not currently supported.
+        """
+        if (not self._enabled):
+            return optimizer.step(*args, **kwargs)
+
+        if "closure" in kwargs:
+            raise RuntimeError(
+                "Closure use is not currently supported if GradScaler is enabled.")
+
+        self._check_scale_growth_tracker("step")
+
+        optimizer_state = self._per_optimizer_states[id(optimizer)]
+
+        if optimizer_state["stage"] is OptState.STEPPED:
+            raise RuntimeError(
+                "step() has already been called since the last update().")
+
+        retval = None
+
+        if (hasattr(optimizer, "_step_supports_amp_scaling") and optimizer._step_supports_amp_scaling):
+            # This optimizer has customized scale-handling logic, so we can call optimizer.step() directly.
+            # The contract with custom optimizers is that their step() should accept an additional,
+            # optional grad_scaler kwarg.  We append self to the kwargs so the custom optimizer has full information:
+            # it can query its own state, invoke unscale_ on itself, etc
+            retval = optimizer.step(*args, **dict(kwargs, grad_scaler=self))
+            optimizer_state["stage"] = OptState.STEPPED
+            return retval
+
+        if optimizer_state["stage"] is OptState.READY:
+            self.unscale_(optimizer)
+
+        assert len(optimizer_state["found_inf_per_device"]
+                   ) > 0, "No inf checks were recorded for this optimizer."
+
+        retval = self._maybe_opt_step(
+            optimizer, optimizer_state, *args, **kwargs)
+
+        optimizer_state["stage"] = OptState.STEPPED
+
+        return retval
+    def update(self, new_scale=None):
+        """
+        Updates the scale factor.
+
+        If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
+        to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
+        the scale is multiplied by ``growth_factor`` to increase it.
+
+        Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
+        used directly, it's used to fill GradScaler's internal scale tensor. So if
+        ``new_scale`` was a tensor, later in-place changes to that tensor will not further
+        affect the scale GradScaler uses internally.)
+
+        Args:
+            new_scale (float or :class:`torch.cuda.FloatTensor`, optional, default=None):  New scale factor.
+
+        .. warning::
+            :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
+            been invoked for all optimizers used this iteration.
+        """
+        if not self._enabled:
+            return
+
+        _scale, _growth_tracker = self._check_scale_growth_tracker("update")
+
+        if new_scale is not None:
+            # Accept a new user-defined scale.
+            if isinstance(new_scale, float):
+                self._scale.fill_(new_scale)  # type: ignore[union-attr]
+            else:
+                reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False."
+                # type: ignore[attr-defined]
+                assert isinstance(new_scale, torch.cuda.FloatTensor), reason
+                assert new_scale.numel() == 1, reason
+                assert new_scale.requires_grad is False, reason
+                self._scale.copy_(new_scale)  # type: ignore[union-attr]
+        else:
+            # Consume shared inf/nan data collected from optimizers to update the scale.
+            # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
+            found_infs = [found_inf.to(device=_scale.device, non_blocking=True)
+                          for state in self._per_optimizer_states.values()
+                          for found_inf in state["found_inf_per_device"].values()]
+
+            assert len(
+                found_infs) > 0, "No inf checks were recorded prior to update."
+
+            found_inf_combined = found_infs[0]
+            if len(found_infs) > 1:
+                for i in range(1, len(found_infs)):
+                    found_inf_combined += found_infs[i]
+
+            torch._amp_update_scale_(_scale,
+                                     _growth_tracker,
+                                     found_inf_combined,
+                                     self._growth_factor,
+                                     self._backoff_factor,
+                                     self._growth_interval)
+
+        # To prepare for next iteration, clear the data collected from optimizers this iteration.
+        self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
+    def _get_scale_async(self):
+        return self._scale
+
+    def get_scale(self):
+        """
+        Returns a Python float containing the current scale, or 1.0 if scaling is disabled.
+
+        .. warning::
+            :meth:`get_scale` incurs a CPU-GPU sync.
+        """
+        if self._enabled:
+            return self._init_scale if self._scale is None else self._get_scale_async().item()
+        else:
+            return 1.0
+
+    def get_growth_factor(self):
+        r"""
+        Returns a Python float containing the scale growth factor.
+        """
+        return self._growth_factor
+
+    def set_growth_factor(self, new_factor):
+        r"""
+        Args:
+            new_scale (float):  Value to use as the new scale growth factor.
+        """
+        self._growth_factor = new_factor
+
+    def get_backoff_factor(self):
+        r"""
+        Returns a Python float containing the scale backoff factor.
+        """
+        return self._backoff_factor
+
+    def set_backoff_factor(self, new_factor):
+        r"""
+        Args:
+            new_scale (float):  Value to use as the new scale backoff factor.
+        """
+        self._backoff_factor = new_factor
+
+    def get_growth_interval(self):
+        r"""
+        Returns a Python int containing the growth interval.
+        """
+        return self._growth_interval
+
+    def set_growth_interval(self, new_interval):
+        r"""
+        Args:
+            new_interval (int):  Value to use as the new growth interval.
+        """
+        self._growth_interval = new_interval
+
+    def _get_growth_tracker(self):
+        if self._enabled:
+            return self._init_growth_tracker if self._growth_tracker is None else self._growth_tracker.item()
+        else:
+            return 0
+
+    def is_enabled(self):
+        r"""
+        Returns a bool indicating whether this instance is enabled.
+        """
+        return self._enabled
+    def state_dict(self):
+        r"""
+        Returns the state of the scaler as a :class:`dict`.  It contains five entries:
+
+        * ``"scale"`` - a Python float containing the current scale
+        * ``"growth_factor"`` - a Python float containing the current growth factor
+        * ``"backoff_factor"`` - a Python float containing the current backoff factor
+        * ``"growth_interval"`` - a Python int containing the current growth interval
+        * ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps.
+
+        If this instance is not enabled, returns an empty dict.
+
+        .. note::
+           If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict`
+           should be called after :meth:`update`.
+        """
+        return {"scale": self.get_scale(),
+                "growth_factor": self._growth_factor,
+                "backoff_factor": self._backoff_factor,
+                "growth_interval": self._growth_interval,
+                "_growth_tracker": self._get_growth_tracker()} if self._enabled else {}
+
+    def load_state_dict(self, state_dict):
+        r"""
+        Loads the scaler state.  If this instance is disabled, :meth:`load_state_dict` is a no-op.
+
+        Args:
+           state_dict(dict): scaler state.  Should be an object returned from a call to :meth:`state_dict`.
+        """
+        if not self._enabled:
+            return
+
+        if len(state_dict) == 0:
+            raise RuntimeError("The source state dict is empty, possibly because it was saved "
+                               "from a disabled instance of GradScaler.")
+
+        self._init_scale = state_dict["scale"]
+        if self._scale is not None:
+            self._scale.fill_(state_dict["scale"])
+        self._growth_factor = state_dict["growth_factor"]
+        self._backoff_factor = state_dict["backoff_factor"]
+        self._growth_interval = state_dict["growth_interval"]
+        self._init_growth_tracker = state_dict["_growth_tracker"]
+        if self._growth_tracker is not None:
+            self._growth_tracker.fill_(state_dict["_growth_tracker"])
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        if self._enabled:
+            assert len(self._per_optimizer_states) == 0, "A GradScaler instance may only be pickled at the beginning "\
+                                                         "of an iteration, or at the end after scaler.update()."
+            # Pickling _scale and _growth_tracker Tensors directly triggers
+            # "warnings.warn("pickle support for Storage will be removed in 1.5..."
+            # so instead, we set the unpickled instance up to reinitialize them lazily.
+            state['_init_scale'] = self.get_scale()
+            state['_init_growth_tracker'] = self._get_growth_tracker()
+            state['_scale'] = None
+            state['_growth_tracker'] = None
+        return state
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+
+    def _check_inf_per_device(self, optimizer):
+        _scale, _ = self._check_scale_growth_tracker("_check_inf_per_device")
+
+        dummy_inv_scale = torch.full(
+            (1,), 1.0, dtype=torch.float32, device=_scale.device)
+        found_inf = torch.full(
+            (1,), 0.0, dtype=torch.float32, device=_scale.device)
+
+        self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \
+            self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True)
+
+        return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]
+
+    def _found_inf_per_device(self, optimizer):
+        return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]
@@ -5,125 +5,85 @@ from abc import ABC, abstractmethod
 
 import torch
 
+from colossalai.core import global_context as gpc
 from colossalai.logging import get_global_dist_logger
 from colossalai.utils import get_current_device
 
 
 class BaseSchedule(ABC):
     """A basic helper class to control the process of training or evaluation.
+    It mainly composes of forward_backward_step for gradient backward and
+    optimizer_step for parameters update.
+    For the convenience to enable FP16, we aggreate all codes that contain the
+    control of FP16 in class schedule.
     """
 
     def __init__(self):
-        self.initialized = False
         self.logger = get_global_dist_logger()
 
-    @property
-    @abstractmethod
-    def num_steps(self):
-        """The number of batches in training or evaluation.
-        """
-        pass
+    @staticmethod
+    def _move_tensor(element):
+        if torch.is_tensor(element):
+            if not element.is_cuda:
+                return element.to(get_current_device()).detach()
+        return element
 
-    def initialize(self,
-                   dataloader=None,
-                   model=None,
-                   criterion=None,
-                   optimizer=None,
-                   lr_scheduler=None):
-        """Initializes the schedule and set parameters before running.
-
-        :param dataloader: DataLoader in training or evaluation
-        :param model: The neural network model
-        :param criterion: Criterion for calculating loss
-        :param optimizer: Optimizer for updating the parameters
-        :param lr_scheduler: Learning rate scheduler in the process
-        """
-        self.dataloader = dataloader
-        assert model is not None, "Schedule requires a model"
-        self.model = model
-        assert criterion is not None, "Schedule requires a criterion"
-        self.criterion = criterion
-        assert optimizer is not None, "Schedule requires an optimizer"
-        self.optimizer = optimizer
-        self.lr_scheduler = lr_scheduler
-        self.initialized = True
-
-    def check_initialized(self):
-        """Checks whether the schedule is initialized.
-        """
-        assert self.initialized, \
-            'Schedule is not initialized. Call schedule.initialize(...) before using it.'
-
-    def load_batch(self):
-        """Loads a batch of dataset. It returns the data and labels which are
+    def _move_to_device(self, data):
+        if isinstance(data, (tuple, list)):
+            data = tuple([self._move_tensor(d) for d in data])
+        elif torch.is_tensor(data):
+            data = data.to(get_current_device()).detach()
+        return data
+
+    def load_batch(self, data_iter):
+        """Loads a batch from data iterator. It returns the data and labels which are
         already in the same GPU as where the model's.
 
         :return: (data, label)
        :rtype: (Tensor, Tensor)
         """
-        self.check_initialized()
-        if self.data_iter is None:
+        if data_iter is None:
             raise RuntimeError('Dataloader is not defined.')
-        data, label = next(self.data_iter)
+        data, label = next(data_iter)
         return self._move_to_device(data), self._move_to_device(label)
 
-    def _move_to_device(self, data):
-        if isinstance(data, (
-                tuple,
-                list,
-        )):
-            data = tuple([
-                d.to(get_current_device()).detach() for d in data
-                if torch.is_tensor(d)
-            ])
-        elif torch.is_tensor(data):
-            data = data.to(get_current_device()).detach()
-        return data
-
-    def train(self, dataloader=None, mode=True):
-        """Sets the dataloader to be used and turn the model to
-        training or evaluation mode.
-
-        :param dataloader: Dataloader to be used
-        :param mode: If True, the model will set as training mode. Otherwise, evaluation mode.
-        """
-        self.check_initialized()
-        if mode:
-            self.model.train()
-        else:
-            self.model.eval()
-        if dataloader is not None:
-            self.dataloader = dataloader
-            self.data_iter = iter(dataloader)
-
-    def zero_grad(self, forward_only=False):
-        """Cleans gradients with the optimizer.
-        """
-        if not forward_only:
-            self.check_initialized()
-            self.optimizer.zero_grad()
-
-    def get_lr(self):
-        """Returns the current learning rate.
-        """
-        if self.lr_scheduler is not None:
-            return self.lr_scheduler.get_lr()[0]
-        else:
-            return self.optimizer.param_groups[0]['lr']
-
-    def step(self):
-        """Updates the parameters and learning rate with the optimizer.
-        """
-        self.check_initialized()
-        self.optimizer.step()
-        # update lr scheduler
-        if self.lr_scheduler is not None:
-            self.lr_scheduler.step()
+    def initialize(self, model, optimizer):
+        """Initializes the model and the optimizer before training.
+        This is often used in FP16 training.
+
+        :param model: The neural network model
+        :param optimizer: Optimizer for updating the parameters
+        """
+        return model, optimizer
 
     @abstractmethod
-    def forward_backward_step(self, forward_only=False, return_loss=True):
+    def forward_backward_step(self,
+                              data_iter,
+                              model,
+                              criterion,
+                              optimizer=None,
+                              forward_only=False,
+                              grad_accum_size: int = 1,
+                              return_loss=True):
         """The process function over a batch of dataset for training or evaluation.
 
-        :param forward_only: If True, the process won't include backward.
-        :param return_loss: If False, the loss won't be returned.
+        :param data_iter: Data iterator of the dataset
+        :param model: Model used in training or evaluation
+        :param optimizer: Optimizer used in training or evaluation
+        :param criterion: Loss function
+        :param forward_only: If True, the process won't include backward
+        :param grad_accum_size: Steps of gradient accumulation
+        :param return_loss: If False, the loss won't be returned
+        """
+        pass
+
+    @abstractmethod
+    def optimizer_step(self, model, optimizer, grad_clipping: float = 0.0):
+        """Updates the parameters with the optimizer.
+
+        :param model: The neural network model
+        :param optimizer: Optimizer for updating the parameters
+        :param grad_clipping: The norm of gradient clipping
+        :type grad_clipping: float, optional
         """
         pass
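Under the reworked contract a concrete schedule supplies two methods instead of owning the model, optimizer and dataloader. A hedged sketch (no FP16, no pipeline); it assumes `BaseSchedule` is importable from `colossalai.engine.schedule` as the imports in this commit suggest, and everything except the two abstract method names is ours:

```python
import torch
from colossalai.engine.schedule import BaseSchedule

class SimpleSchedule(BaseSchedule):
    def forward_backward_step(self, data_iter, model, criterion, optimizer=None,
                              forward_only=False, grad_accum_size: int = 1, return_loss=True):
        data, label = self.load_batch(data_iter)
        output = model(data)
        loss = criterion(output, label) if return_loss else None
        if not forward_only:
            # divide by grad_accum_size so accumulated gradients average over micro-batches
            (loss / grad_accum_size).backward()
        return output, label, loss

    def optimizer_step(self, model, optimizer, grad_clipping: float = 0.0):
        if grad_clipping > 0.0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clipping)
        optimizer.step()
```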
@@ -4,19 +4,24 @@
 try:
     import apex.amp as apex_amp
 except:
-    print('apex is required for mixed precision training')
+    pass
 
 try:
     import torch.cuda.amp as torch_amp
 except:
-    print('PyTorch amp is not supported with the current PyTorch version')
+    pass
 
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.engine.amp_type import AMP_TYPE
+from typing import Iterable
+
+import torch.nn as nn
+from torch.optim import Optimizer
+
 from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
                            ZeroRedundancyOptimizer_Level_3)
-from ._utils import convert_to_fp16
+from colossalai.nn.optimizer._utils import clip_grad_norm_fp32
 from ._base_schedule import BaseSchedule
+from ._utils import convert_to_fp16, convert_to_fp32
+from ..amp import AMP_TYPE, GradScaler
 
 
 class NoPipelineSchedule(BaseSchedule):
|
||||||
:type amp_type: AMP_TYPE
|
:type amp_type: AMP_TYPE
|
||||||
:type amp_config: dict
|
:type amp_config: dict
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
amp_type: AMP_TYPE = None,
|
amp_type: AMP_TYPE = None,
|
||||||
|
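A hedged sketch of enabling torch-style AMP on the no-pipeline schedule. It assumes the `amp_config` argument named in the docstring above is a keyword parameter of `__init__` and is forwarded to the TP-aware `GradScaler`, so it takes the usual GradScaler keyword arguments:

```python
from colossalai.engine import AMP_TYPE
from colossalai.engine.schedule import NoPipelineSchedule

schedule = NoPipelineSchedule(
    amp_type=AMP_TYPE.TORCH,
    amp_config=dict(init_scale=2.**16, growth_interval=2000))
```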
@@ -41,12 +47,6 @@ class NoPipelineSchedule(BaseSchedule):
         assert amp_type is None or isinstance(amp_type, AMP_TYPE), \
             'unrecognised value for argument fp16, it can only be None, torch or apex'
-
-        # LSG: check compatibility
-        # LSG: torch.cuda.amp and apex.amp cannot be used for tensor parallel
-        if gpc.is_initialized(ParallelMode.TENSOR) and gpc.get_world_size(
-                ParallelMode.TENSOR) > 1:
-            assert amp_type != AMP_TYPE.TORCH and amp_type != AMP_TYPE.APEX, \
-                'You can only AMP_TYPE.PARALLEL for tensor parallel training'
         self.use_zero_level_2_3 = False
 
         if amp_type is not None:
@ -79,107 +79,110 @@ class NoPipelineSchedule(BaseSchedule):
|
||||||
self.fp16 = False
|
self.fp16 = False
|
||||||
self.amp_type = None
|
self.amp_type = None
|
||||||
|
|
||||||
@property
|
def initialize(self, model: nn.Module, optimizer: Optimizer):
|
||||||
def num_steps(self):
|
if isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2,
|
||||||
return len(self.dataloader)
|
|
||||||
|
|
||||||
def initialize(self,
|
|
||||||
dataloader,
|
|
||||||
model,
|
|
||||||
criterion,
|
|
||||||
optimizer,
|
|
||||||
lr_scheduler=None):
|
|
||||||
super().initialize(dataloader,
|
|
||||||
model,
|
|
||||||
criterion,
|
|
||||||
optimizer,
|
|
||||||
lr_scheduler=lr_scheduler)
|
|
||||||
if isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2,
|
|
||||||
ZeroRedundancyOptimizer_Level_3)):
|
ZeroRedundancyOptimizer_Level_3)):
|
||||||
self.use_zero_level_2_3 = True
|
self.use_zero_level_2_3 = True
|
||||||
assert self.amp_type != AMP_TYPE.PARALLEL, 'ZeRO Level 2 and 3 are mutually exclusive with AMP_TYPE.PARALLEL'
|
assert self.amp_type != AMP_TYPE.PARALLEL, \
|
||||||
|
'ZeRO Level 2 and 3 are mutually exclusive with AMP_TYPE.PARALLEL'
|
||||||
|
|
||||||
if self.fp16:
|
if self.fp16:
|
||||||
if self.amp_type == AMP_TYPE.TORCH:
|
if self.amp_type == AMP_TYPE.TORCH:
|
||||||
self._torch_amp_scaler = torch_amp.GradScaler(**self.amp_cfg)
|
self._torch_amp_scaler = GradScaler(**self.amp_cfg)
|
||||||
elif self.amp_type == AMP_TYPE.APEX:
|
elif self.amp_type == AMP_TYPE.APEX:
|
||||||
self.model, self.optimizer = apex_amp.initialize(
|
model, optimizer = apex_amp.initialize(model, optimizer, **self.amp_cfg)
|
||||||
self.model, self.optimizer, **self.amp_cfg)
|
|
||||||
|
|
||||||
def forward_backward_step(self, forward_only=False, return_loss=True):
|
return model, optimizer
|
||||||
|
|
||||||
|
def forward_backward_step(self,
|
||||||
|
data_iter: Iterable,
|
||||||
|
model: nn.Module,
|
||||||
|
criterion: nn.modules.loss._Loss,
|
||||||
|
optimizer: Optimizer = None,
|
||||||
|
forward_only: bool = False,
|
||||||
|
grad_accum_size: int = 1,
|
||||||
|
return_loss: bool = True):
|
||||||
"""The process function that loads loads a batch of dataset and feeds it to the model.
|
"""The process function that loads loads a batch of dataset and feeds it to the model.
|
||||||
The returned labels and loss will None if :attr:`return_loss` is False.
|
The returned labels and loss will None if :attr:`return_loss` is False.
|
||||||
|
|
||||||
|
:param data_iter: Data iterator of the dataloader, e.g. iter(dataloader)
|
||||||
|
:param model: Model for training and inference
|
||||||
|
:param criterion: Loss function for training
|
||||||
|
:param optimizer: Optimizer used for training
|
||||||
|
:param forward_only: If True, the model is run for the forward pass, else back propagation will be executed
|
||||||
|
:param grad_accum_size: The number of iterations for gradient accumulation
|
||||||
|
:param return_loss: Loss will be returned if True
|
||||||
|
:type data_iter: Iterator
|
||||||
|
:type model: torch.nn.Module
|
||||||
|
:type criterion: torch.nn.modules.loss._Loss
|
||||||
|
:type optimizer: torch.optim.Optimizer
|
||||||
|
:type forward_only: bool, optional
|
||||||
|
:type grad_accum_size: int
|
||||||
|
:type return_loss: bool, optional
|
||||||
:return: (output, label, loss)
|
:return: (output, label, loss)
|
||||||
"""
|
"""
|
||||||
assert forward_only or return_loss, \
|
assert forward_only or return_loss, \
|
||||||
'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
|
'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
|
||||||
|
|
||||||
data, label = self.load_batch()
|
data, label = self.load_batch(data_iter)
|
||||||
loss = None
|
loss = None
|
||||||
|
|
||||||
# LSG: leave for debug, make sure dataloader is deterministic
|
|
||||||
# if forward_only:
|
|
||||||
# img = data[0]
|
|
||||||
# rank = gpc.get_local_rank(ParallelMode.DATA)
|
|
||||||
# world_size = gpc.get_world_size(ParallelMode.DATA)
|
|
||||||
# group = gpc.get_group(ParallelMode.DATA)
|
|
||||||
# input_list = [img.clone() for _ in range(world_size)]
|
|
||||||
# output_list = [torch.empty_like(img) for _ in range(world_size)]
|
|
||||||
# output_list[rank] = img.clone()
|
|
||||||
# dist.all_to_all(output_tensor_list=output_list, input_tensor_list=input_list, group=group)
|
|
||||||
# assert torch.equal(output_list[0], output_list[1]) # and torch.equal(output_list[1], output_list[2])
|
|
||||||
|
|
||||||
# forward
|
# forward
|
||||||
if self.fp16 and self.amp_type == AMP_TYPE.TORCH:
|
if self.fp16 and self.amp_type == AMP_TYPE.TORCH:
|
||||||
with torch_amp.autocast():
|
with torch_amp.autocast():
|
||||||
output = self.model(*data)
|
output = model(*data)
|
||||||
if not isinstance(output, (tuple, list)):
|
if not isinstance(output, (tuple, list)):
|
||||||
output = (output,)
|
output = (output,)
|
||||||
if return_loss:
|
if return_loss:
|
||||||
loss = self.criterion(*output, *label)
|
loss = criterion(*output, *label)
|
||||||
else:
|
else:
|
||||||
if self.use_zero_level_2_3 or self.amp_type == AMP_TYPE.PARALLEL:
|
if self.use_zero_level_2_3 or self.amp_type == AMP_TYPE.PARALLEL:
|
||||||
data = convert_to_fp16(data)
|
data = convert_to_fp16(data)
|
||||||
|
|
||||||
output = self.model(*data)
|
output = model(*data)
|
||||||
|
|
||||||
|
if self.use_zero_level_2_3 or self.amp_type == AMP_TYPE.PARALLEL:
|
||||||
|
output = convert_to_fp32(output)
|
||||||
|
|
||||||
if not isinstance(output, (tuple, list)):
|
if not isinstance(output, (tuple, list)):
|
||||||
output = (output,)
|
output = (output,)
|
||||||
if return_loss:
|
if return_loss:
|
||||||
loss = self.criterion(*output, *label)
|
loss = criterion(*output, *label)
|
||||||
|
|
||||||
|
loss /= grad_accum_size
|
||||||
|
|
||||||
if not forward_only:
|
if not forward_only:
|
||||||
# backward
|
# backward
|
||||||
if self.use_zero_level_2_3:
|
if self.use_zero_level_2_3:
|
||||||
self.optimizer.backward(loss)
|
optimizer.backward(loss)
|
||||||
elif self.fp16:
|
elif self.fp16:
|
||||||
if self.amp_type == AMP_TYPE.APEX:
|
if self.amp_type == AMP_TYPE.APEX:
|
||||||
with apex_amp.scale_loss(loss,
|
with apex_amp.scale_loss(loss, optimizer) as scaled_loss:
|
||||||
self.optimizer) as scaled_loss:
|
|
||||||
scaled_loss.backward()
|
scaled_loss.backward()
|
||||||
elif self.amp_type == AMP_TYPE.TORCH:
|
elif self.amp_type == AMP_TYPE.TORCH:
|
||||||
self._torch_amp_scaler.scale(loss).backward()
|
self._torch_amp_scaler.scale(loss).backward()
|
||||||
elif self.amp_type == AMP_TYPE.PARALLEL:
|
elif self.amp_type == AMP_TYPE.PARALLEL:
|
||||||
loss = self.optimizer.scale_loss(loss)
|
loss = optimizer.scale_loss(loss)
|
||||||
loss.backward()
|
loss.backward()
|
||||||
# scale back to display the original value in logs
|
# scale back to display the original value in logs
|
||||||
loss.div_(self.optimizer.grad_scaler.scale)
|
loss.div_(optimizer.grad_scaler.scale)
|
||||||
else:
|
else:
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
if return_loss:
|
if return_loss:
|
||||||
return output, label, loss
|
return output, label, loss * grad_accum_size
|
||||||
else:
|
else:
|
||||||
return output, None, None
|
return output, None, None
|
||||||
|
|
||||||
def step(self):
|
def optimizer_step(self, model: nn.Module, optimizer: Optimizer, grad_clipping: float = 0.0):
|
||||||
# step optimizer
|
# step optimizer
|
||||||
if self.fp16 and self.amp_type == AMP_TYPE.TORCH:
|
if self.fp16 and self.amp_type == AMP_TYPE.TORCH:
|
||||||
self._torch_amp_scaler.step(self.optimizer)
|
if grad_clipping > 0.0:
|
||||||
|
self._torch_amp_scaler.unscale_(optimizer)
|
||||||
|
clip_grad_norm_fp32(model.parameters(), grad_clipping)
|
||||||
|
self._torch_amp_scaler.step(optimizer)
|
||||||
self._torch_amp_scaler.update()
|
self._torch_amp_scaler.update()
|
||||||
else:
|
else:
|
||||||
self.optimizer.step()
|
if not self.fp16 and not self.use_zero_level_2_3 and grad_clipping > 0.0:
|
||||||
|
clip_grad_norm_fp32(model.parameters(), grad_clipping)
|
||||||
# update lr scheduler
|
optimizer.step()
|
||||||
if self.lr_scheduler is not None:
|
|
||||||
self.lr_scheduler.step()
|
|
||||||
|
|
|
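After this change the schedule no longer owns the model, optimizer or lr_scheduler; they are passed in on every call. A hedged sketch of how a caller (in practice the Engine) might drive one gradient-accumulation cycle with the new signatures, assuming `schedule`, `model`, `criterion`, `optimizer` and `train_dataloader` were built elsewhere:

```python
# Illustrative loop only; the real Engine wraps this logic.
grad_accum_size = 4                      # hypothetical accumulation factor
data_iter = iter(train_dataloader)

model, optimizer = schedule.initialize(model, optimizer)
for _ in range(grad_accum_size):
    # each call consumes one batch and scales its loss by 1/grad_accum_size
    output, label, loss = schedule.forward_backward_step(
        data_iter, model, criterion, optimizer,
        forward_only=False, grad_accum_size=grad_accum_size, return_loss=True)

# one parameter update after the accumulated backward passes,
# with optional gradient clipping by norm
schedule.optimizer_step(model, optimizer, grad_clipping=1.0)
optimizer.zero_grad()
```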
@@ -15,7 +15,7 @@ from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
 from colossalai.utils import get_current_device
 from ._base_schedule import BaseSchedule
 from ._utils import convert_to_fp16
-from ..amp_type import AMP_TYPE
+from ..amp import AMP_TYPE
 
 
 def squeeze(x: Union[Tensor, tuple, list]):
@@ -93,12 +93,11 @@ class PipelineSchedule(BaseSchedule):
         )
 
     # Pipeline schedule just puts data in memory
-    def load_batch(self):
-        self.check_initialized()
-        if self.data_iter is None:
+    def load_batch(self, data_iter):
+        if data_iter is None:
             raise RuntimeError('Dataloader is not defined.')
         self.batch_pos = 0
-        data, label = next(self.data_iter)
+        data, label = next(data_iter)
         self.batch_data, self.batch_label = \
             self._move_to_device(data), self._move_to_device(label)
         batch_size = self.batch_data.shape[0]
@@ -117,23 +116,8 @@ class PipelineSchedule(BaseSchedule):
         self.batch_pos += self.microbatch_size
         return (data,), (label,)
 
-    @property
-    def num_steps(self):
-        return len(self.dataloader)
-
-    def initialize(self,
-                   dataloader,
-                   model,
-                   criterion,
-                   optimizer,
-                   lr_scheduler=None):
-        super().initialize(dataloader,
-                           model,
-                           criterion,
-                           optimizer,
-                           lr_scheduler=lr_scheduler)
-        if isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2,
-                                       ZeroRedundancyOptimizer_Level_3)):
+    def initialize(self, model, optimizer):
+        if isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)):
             raise TypeError(
                 "Pipeline schedule is currently not compatible with ZeRO Level 2 and Level 3"
             )
@@ -145,7 +129,8 @@ class PipelineSchedule(BaseSchedule):
                 'default tensor dtype is set to torch.half for fp16 training',
                 ranks=[0])
 
-    def forward_step(self, input_tensor, return_tensors, return_loss=True):
+    def forward_step(self, model, criterion, input_tensor, return_tensors,
+                     grad_accum_size, return_loss=True):
         """Forward step for passed-in model. If it is the first stage, the input tensor
         is obtained from data_iterator, otherwise the passed-in input_tensor is used.
         Returns output tensor. This is a helper function and can be ignored by users.
@@ -156,14 +141,14 @@ class PipelineSchedule(BaseSchedule):
         if self.amp_type == AMP_TYPE.PARALLEL:
             input_tensor = convert_to_fp16(input_tensor)
         input_tensor = squeeze(input_tensor)
-        output_tensor = self.model(input_tensor)
+        output_tensor = model(input_tensor)
         output_tensor = squeeze(output_tensor)
 
         if gpc.is_last_rank(ParallelMode.PIPELINE):
             if return_loss:
                 input_tensor, label = self.load_micro_batch()
-                loss_reduced = self.criterion(output_tensor, *
-                                              label) / self.num_microbatches
+                loss_reduced = criterion(output_tensor, *label) \
+                    / (self.num_microbatches * grad_accum_size)
                 return_tensors.append(
                     tuple((output_tensor, label[0], loss_reduced)))
                 return loss_reduced
@@ -174,7 +159,7 @@ class PipelineSchedule(BaseSchedule):
         else:
             return output_tensor
 
-    def backward_step(self, input_tensor, output_tensor, output_tensor_grad):
+    def backward_step(self, optimizer, input_tensor, output_tensor, output_tensor_grad):
         """Backward step through the passed-in output tensor. If it is the last stage, the
         output_tensor_grad is None, otherwise it is the gradients with respect to stage's output tensor.
         Returns the gradients with respect to the input tensor (None if first stage).
@@ -187,7 +172,7 @@ class PipelineSchedule(BaseSchedule):
 
         # Backward pass.
         if output_tensor_grad is None and self.amp_type == AMP_TYPE.PARALLEL:
-            output_tensor = self.optimizer.scale_loss(output_tensor)
+            output_tensor = optimizer.scale_loss(output_tensor)
         torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
 
         # Collect the grad of the input_tensor.
@@ -197,7 +182,14 @@ class PipelineSchedule(BaseSchedule):
 
         return input_tensor_grad
 
-    def forward_backward_step(self, forward_only=True, return_loss=True):
+    def forward_backward_step(self,
+                              data_iter,
+                              model,
+                              criterion,
+                              optimizer=None,
+                              forward_only=False,
+                              grad_accum_size: int = 1,
+                              return_loss=True):
         """Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
         Returns a tuple with losses if the last stage, an empty tuple otherwise.
 
@@ -207,7 +199,7 @@ class PipelineSchedule(BaseSchedule):
         assert forward_only or return_loss, \
             'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
 
-        self.load_batch()
+        self.load_batch(data_iter)
         num_warmup_microbatches = \
             (gpc.get_world_size(ParallelMode.PIPELINE) -
              gpc.get_local_rank(ParallelMode.PIPELINE) - 1)
@@ -233,9 +225,11 @@ class PipelineSchedule(BaseSchedule):
             if not gpc.is_first_rank(ParallelMode.PIPELINE):
                 ft_shape = recv_tensor_meta(ft_shape)
             input_tensor = recv_forward(ft_shape)
-            output_tensor = self.forward_step(input_tensor,
-                                              return_tensors,
-                                              return_loss=return_loss)
+            output_tensor = self.forward_step(
+                model, criterion,
+                input_tensor, return_tensors,
+                grad_accum_size, return_loss=return_loss
+            )
             if not gpc.is_last_rank(ParallelMode.PIPELINE):
                 bt_shape = output_tensor.shape
                 fs_checker = send_tensor_meta(output_tensor, fs_checker)
@@ -257,9 +251,11 @@ class PipelineSchedule(BaseSchedule):
         for i in range(num_microbatches_remaining):
             last_iteration = (i == (num_microbatches_remaining - 1))
 
-            output_tensor = self.forward_step(input_tensor,
-                                              return_tensors,
-                                              return_loss=return_loss)
+            output_tensor = self.forward_step(
+                model, criterion,
+                input_tensor, return_tensors,
+                grad_accum_size, return_loss=return_loss
+            )
             if forward_only:
                 send_forward(output_tensor)
 
@@ -279,9 +275,11 @@ class PipelineSchedule(BaseSchedule):
                 input_tensor = input_tensors.pop(0)
                 output_tensor = output_tensors.pop(0)
 
-                input_tensor_grad = self.backward_step(input_tensor,
-                                                       output_tensor,
-                                                       output_tensor_grad)
+                input_tensor_grad = self.backward_step(
+                    optimizer,
+                    input_tensor, output_tensor,
+                    output_tensor_grad
+                )
 
                 if last_iteration:
                     input_tensor = None
@@ -298,9 +296,11 @@ class PipelineSchedule(BaseSchedule):
 
                 output_tensor_grad = recv_backward(bt_shape)
 
-                input_tensor_grad = self.backward_step(input_tensor,
-                                                       output_tensor,
-                                                       output_tensor_grad)
+                input_tensor_grad = self.backward_step(
+                    optimizer,
+                    input_tensor, output_tensor,
+                    output_tensor_grad
+                )
 
                 send_backward(input_tensor_grad)
 
@@ -309,8 +309,11 @@ class PipelineSchedule(BaseSchedule):
                 output, label, loss = tuple(map(list, zip(*return_tensors)))
                 return (torch.cat(output, dim=0),
                         torch.cat(label, dim=0),
-                        sum(loss))
+                        sum(loss) * grad_accum_size)
             else:
                 return tuple((torch.cat(return_tensors, dim=0), None, None))
         else:
            return tuple((None, None, None))
 
+    def optimizer_step(self, model, optimizer, grad_clipping: float = 0.0):
+        optimizer.step()
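One detail worth noting from the hunks above: every micro-batch loss is now divided by `num_microbatches * grad_accum_size`, and the value returned at the end is the micro-batch sum multiplied back by `grad_accum_size`, so the reported number is still the average loss over the micro-batches. A toy numeric check of that bookkeeping (plain Python, no framework required):

```python
# 4 microbatches, gradient accumulation over 2 schedule calls
num_microbatches, grad_accum_size = 4, 2
micro_losses = [2.0, 2.0, 2.0, 2.0]                      # hypothetical per-microbatch losses

reduced = [l / (num_microbatches * grad_accum_size) for l in micro_losses]
reported = sum(reduced) * grad_accum_size                # what forward_backward_step returns

assert abs(reported - sum(micro_losses) / num_microbatches) < 1e-9
```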
@@ -14,3 +14,14 @@ def convert_to_fp16(data: Union[Tensor, List[Tensor]]):
     else:
         raise TypeError(f"Expected argument 'data' to be a Tensor or a list/tuple of Tensor, but got {type(data)}")
     return ret
+
+
+def convert_to_fp32(data: Union[Tensor, List[Tensor]]):
+    if isinstance(data, Tensor):
+        ret = data.float()
+    elif isinstance(data, (list, tuple)):
+        ret = [val.float() for val in data]
+    else:
+        raise TypeError(f"Expected argument 'data' to be a Tensor or a list/tuple of Tensor, but got {type(data)}")
+    return ret
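The new `convert_to_fp32` mirrors `convert_to_fp16` and is what the non-pipeline schedule uses to cast outputs back after a ZeRO or AMP_TYPE.PARALLEL forward pass. A small round-trip sketch; the import path is an assumption, only the two helper names come from the diff:

```python
import torch
from colossalai.engine.schedule._utils import convert_to_fp16, convert_to_fp32  # assumed path

x = [torch.randn(2, 2), torch.randn(3)]
half = convert_to_fp16(x)     # list of fp16 tensors
full = convert_to_fp32(half)  # cast back to fp32
assert all(t.dtype == torch.float32 for t in full)
```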
@@ -6,18 +6,20 @@ import pprint
 import random
 from pathlib import Path
 from typing import Callable, Iterable, Optional, Union
+from typing import Tuple
 
 import numpy as np
 import torch
 from torch.utils.data import DataLoader
 
 from colossalai.engine import AMP_TYPE, NoPipelineSchedule, PipelineSchedule
+from colossalai.engine import Engine
 from colossalai.logging import get_global_dist_logger, init_global_dist_logger
 from colossalai.nn import DataParallelSampler
 from colossalai.nn.model.base_model import BaseModel
 from .builder import (ModelInitializer, build_dataset, build_loss,
-                      build_lr_scheduler, build_model, build_optimizer,
-                      build_optimizer_wrapper)
+                      build_model, build_optimizer,
+                      build_optimizer_wrapper, build_schedule)
 from .context import Config, ParallelMode
 from .core import global_context as gpc
 from .utils import get_current_device, sync_model_param_in_dp
@@ -182,7 +184,7 @@ def initialize(config: Union[str, dict] = None,
                backend: str = None,
                train_dataloader: Optional[Union[Iterable, Callable]] = None,
                test_dataloader: Optional[Union[Iterable, Callable]] = None,
-               ):
+               ) -> Tuple[Engine, DataLoader, DataLoader]:
     '''Core function that initializes distributed environment, logger, cudnn, data, model, loss function, optimizer, and lr_scheduler(their configs are in gpc.config).
 
     :param config: config file or config file path are both acceptable
@@ -201,7 +203,7 @@ def initialize(config: Union[str, dict] = None,
     :type train_dataloader: Optional[Union[Iterable, Callable]], optional
     :param test_dataloader: If None, the config is used to build a dataloder; Else, it should be a dataloader object or a function with no arguments which can build a dataloader, defaults to None
     :type test_dataloader: Optional[Union[Iterable, Callable]], optional
-    :return: (model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler)
+    :return: (engine, train_dataloader, test_dataloader, criterion)
     :rtype: tuple
     '''
     # initialize distributed environment
@@ -337,21 +339,7 @@ def initialize(config: Union[str, dict] = None,
         optimizer = build_optimizer_wrapper(fp16_cfg, optimizer)
     logger.info('Optimizer is created', ranks=[0])
 
-    lr_scheduler = None
-    if hasattr(gpc.config, 'lr_scheduler'):
-        if hasattr(gpc.config, 'num_steps'):
-            total_steps = gpc.config.num_steps
-        elif hasattr(gpc.config, 'num_epochs'):
-            total_steps = int(gpc.config.num_epochs * len(train_dataloader))
-        else:
-            raise Exception(
-                'Please specify training stopping criterion num_steps or num_epochs in your configuration.'
-            )
-        lr_scheduler = build_lr_scheduler(gpc.config.lr_scheduler, optimizer,
-                                          total_steps, len(train_dataloader))
-        logger.info('Learning rate scheduler is created', ranks=[0])
-
-    # pipeline or no pipeline schedule
+    # build schedule and engine
     if hasattr(gpc.config, 'fp16'):
         amp_type = gpc.config.fp16.mode
         amp_cfg = gpc.config.fp16.copy()
@@ -360,12 +348,32 @@ def initialize(config: Union[str, dict] = None,
         amp_type = None
         amp_cfg = None
 
-    if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
-        assert hasattr(gpc.config,
-                       'schedule'), "Config 'schedule' not found in your configuration file for pipeline parallel training"
+    engine_cfg = gpc.config.get('engine', dict())
+    schedule_cfg = engine_cfg.pop('schedule', None)
+
+    schedule_type = None
+    if schedule_cfg is not None:
+        schedule_type = schedule_cfg.get('type', None)
+
+    if schedule_type is not None:
+        # run customized schedule
+        schedule_cfg['amp_type'] = amp_type
+        schedule_cfg['amp_config'] = amp_cfg
+        schedule = build_schedule(schedule_cfg)
+    elif gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
+        assert schedule_cfg is not None, \
+            "Config 'engine.schedule' not found in your configuration file for pipeline parallel training"
         schedule = PipelineSchedule(
-            amp_type=amp_type, amp_config=amp_cfg, **gpc.config.schedule.copy())
+            amp_type=amp_type, amp_config=amp_cfg, **schedule_cfg.copy())
     else:
         schedule = NoPipelineSchedule(amp_type=amp_type, amp_config=amp_cfg)
 
-    return model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler
+    engine = Engine(
+        model=model,
+        optimizer=optimizer,
+        criterion=criterion,
+        step_schedule=schedule,
+        **gpc.config.get('engine', dict())
+    )
+
+    return engine, train_dataloader, test_dataloader
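In other words, `initialize()` now reads the AMP settings from `fp16` and the (optional) schedule from `engine.schedule`, builds the schedule and the Engine itself, and hands back `(engine, train_dataloader, test_dataloader)`. A hedged sketch of a matching config module; only `fp16.mode`, `engine` and `engine.schedule` are taken from the diff, the remaining keys and values are illustrative:

```python
# config.py (illustrative)
from colossalai.engine import AMP_TYPE

num_epochs = 10

fp16 = dict(mode=AMP_TYPE.TORCH)            # forwarded to the schedule as amp_config

engine = dict(
    # only needed for pipeline runs or a custom registered schedule;
    # 'num_microbatches' is an assumed parameter name
    schedule=dict(num_microbatches=4)
)
```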
@@ -7,6 +7,7 @@ from torch import Tensor
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.utils import get_current_device
+from torch.cuda.amp import custom_bwd, custom_fwd
 
 
 def matmul_2d(a,
@@ -60,6 +61,7 @@ class Matmul_AB_2D(torch.autograd.Function):
     """Matrix multiplication for :math:`C = AB`
     """
     @staticmethod
+    @custom_fwd(cast_inputs=torch.float16)
     def forward(ctx: Any,
                 A: Tensor,
                 B: Tensor,
@@ -120,10 +122,11 @@ class Matmul_AB_2D(torch.autograd.Function):
         return out
 
     @staticmethod
+    @custom_bwd
     def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
         A, B = ctx.saved_tensors
-        A_grad = Matmul_ABT_2D.forward(
-            None,
+        with torch.no_grad():
+            A_grad = Matmul_ABT_2D.apply(
                 output_grad, B,
                 ctx.summa_dim, ctx.A_shape,
                 ctx.row_rank, ctx.col_rank,
@@ -134,8 +137,7 @@ class Matmul_AB_2D(torch.autograd.Function):
                 ctx.pipeline_parallel_size,
                 ctx.tensor_parallel_size
             )
-        B_grad = Matmul_ATB_2D.forward(
-            None,
+            B_grad = Matmul_ATB_2D.apply(
                 A, output_grad,
                 ctx.summa_dim, ctx.B_shape,
                 ctx.row_rank, ctx.col_rank,
@@ -153,6 +155,7 @@ class Matmul_ABT_2D(torch.autograd.Function):
     """Matrix multiplication for :math:`C = AB^T`
     """
     @staticmethod
+    @custom_fwd(cast_inputs=torch.float16)
     def forward(ctx: Any,
                 A: Tensor,
                 B: Tensor,
@@ -214,10 +217,12 @@ class Matmul_ABT_2D(torch.autograd.Function):
         return out
 
     @staticmethod
+    @custom_bwd
     def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
         A, B = ctx.saved_tensors
-        A_grad = Matmul_AB_2D.forward(
-            None,
+        with torch.no_grad():
+            A_grad = Matmul_AB_2D.apply(
                 output_grad, B,
                 ctx.summa_dim, ctx.A_shape,
                 ctx.row_rank, ctx.col_rank,
@@ -228,8 +233,7 @@ class Matmul_ABT_2D(torch.autograd.Function):
                 ctx.pipeline_parallel_size,
                 ctx.tensor_parallel_size
             )
-        B_grad = Matmul_ATB_2D.forward(
-            None,
+            B_grad = Matmul_ATB_2D.apply(
                 output_grad, A,
                 ctx.summa_dim, ctx.B_shape,
                 ctx.row_rank, ctx.col_rank,
@@ -247,6 +251,7 @@ class Matmul_ATB_2D(torch.autograd.Function):
     """Matrix multiplication for :math:`C = A^TB`
     """
     @staticmethod
+    @custom_fwd(cast_inputs=torch.float16)
     def forward(ctx: Any,
                 A: Tensor,
                 B: Tensor,
@@ -308,10 +313,12 @@ class Matmul_ATB_2D(torch.autograd.Function):
         return out
 
     @staticmethod
+    @custom_bwd
     def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
         A, B = ctx.saved_tensors
-        A_grad = Matmul_ABT_2D.forward(
-            None,
+        with torch.no_grad():
+            A_grad = Matmul_ABT_2D.apply(
                 B, output_grad,
                 ctx.summa_dim, ctx.A_shape,
                 ctx.row_rank, ctx.col_rank,
@@ -322,8 +329,7 @@ class Matmul_ATB_2D(torch.autograd.Function):
                 ctx.pipeline_parallel_size,
                 ctx.tensor_parallel_size
             )
-        B_grad = Matmul_AB_2D.forward(
-            None,
+            B_grad = Matmul_AB_2D.apply(
                 A, output_grad,
                 ctx.summa_dim, ctx.B_shape,
                 ctx.row_rank, ctx.col_rank,
@@ -341,6 +347,7 @@ class Add_Bias_2D(torch.autograd.Function):
     """Matrix add bias: :math:`C = A + b`
     """
     @staticmethod
+    @custom_fwd(cast_inputs=torch.float16)
     def forward(ctx: Any,
                 input: Tensor,
                 bias: Tensor,
@@ -384,6 +391,7 @@ class Add_Bias_2D(torch.autograd.Function):
         return output
 
     @staticmethod
+    @custom_bwd
     def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
         row_rank = ctx.row_rank
         col_rank = ctx.col_rank
@@ -423,6 +431,7 @@ class Add_Bias_2D(torch.autograd.Function):
 class _LayerNorm_2D(torch.autograd.Function):
 
     @staticmethod
+    @custom_fwd(cast_inputs=torch.float32)
     def forward(ctx: Any,
                 input: Tensor,
                 E_x: Tensor,
@@ -440,6 +449,7 @@ class _LayerNorm_2D(torch.autograd.Function):
         return output
 
     @staticmethod
+    @custom_bwd
     def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
         row_parallel_mode = ctx.row_parallel_mode
         col_parallel_mode = ctx.col_parallel_mode
@@ -492,6 +502,7 @@ class _LayerNorm_2D(torch.autograd.Function):
 class _ViT_Split_Input_2D(torch.autograd.Function):
 
     @staticmethod
+    @custom_fwd(cast_inputs=torch.float16)
     def forward(ctx: Any,
                 inputs: Tensor,
                 batch_size: int,
@@ -509,6 +520,7 @@ class _ViT_Split_Input_2D(torch.autograd.Function):
         return output
 
     @staticmethod
+    @custom_bwd
     def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
         # output_grad: [b/q, s, h/q]
         # grads: [b, s, h/q]
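The pattern applied throughout the 2D parallel ops above is the standard PyTorch recipe for making a custom autograd Function cooperate with torch AMP: decorate `forward` with `custom_fwd(cast_inputs=...)` so inputs are cast under autocast, decorate `backward` with `custom_bwd`, and call other Functions via `.apply` inside `torch.no_grad()` instead of invoking their `forward` directly. A minimal, self-contained illustration of that recipe (not the SUMMA matmul itself):

```python
import torch
from torch.cuda.amp import custom_fwd, custom_bwd


class ScaleByTwo(torch.autograd.Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(ctx, x):
        # under autocast, x arrives already cast to fp16
        return x * 2

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        # custom_bwd runs this in the dtype recorded by the forward pass
        return grad_output * 2


if torch.cuda.is_available():
    x = torch.randn(4, device='cuda', requires_grad=True)
    with torch.cuda.amp.autocast():
        y = ScaleByTwo.apply(x).sum()
    y.backward()
```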
@@ -1,5 +1,5 @@
 from .cosine import CosineAnnealingLR, CosineAnnealingWarmupLR, FlatAnnealingLR, FlatAnnealingWarmupLR
-from .linear import LinearWarmupLR, LinearWarmupDecay
+from .linear import LinearWarmupLR
 from .multistep import MultiStepLR, MultiStepWarmupLR
 from .onecycle import OneCycleLR
 from .poly import PolynomialLR, PolynomialWarmupLR
 
@@ -66,11 +66,10 @@ class CosineAnnealingWarmupLR(WarmupScheduler):
     :type last_epoch: int, optional
     """
 
-    def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, eta_min: int = 0, last_epoch: int = -1,
-                 **kwargs):
+    def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, eta_min: int = 0, last_epoch: int = -1):
         base_scheduler = _CosineAnnealingLR(
-            optimizer, total_steps - warmup_steps, eta_min=eta_min)
-        super().__init__(optimizer, warmup_steps, base_scheduler, last_epoch=last_epoch)
+            optimizer, total_steps - warmup_steps, eta_min=eta_min, last_epoch=last_epoch)
+        super().__init__(optimizer, warmup_steps, base_scheduler)
 
 
 @LR_SCHEDULERS.register_module
@@ -66,11 +66,8 @@ class WarmupScheduler(_LRScheduler):
     :param last_epoch: The index of last epoch, defaults to -1
     :type last_epoch: int, optional
     """
 
     def __init__(self, optimizer, warmup_epochs, after_scheduler, last_epoch=-1):
-        if warmup_epochs < 0:
-            raise ValueError(f'warmup_epochs must >= 0, got {warmup_epochs}')
-        self.warmup_epochs = warmup_epochs
+        self.warmup_epochs = int(warmup_epochs)
         self.after_scheduler = after_scheduler
         self.finished = False
         super().__init__(optimizer, last_epoch)
@@ -79,14 +76,10 @@ class WarmupScheduler(_LRScheduler):
         if self.last_epoch >= self.warmup_epochs:
             if not self.finished:
                 self.after_scheduler.base_lrs = self.base_lrs
-                # reset lr to base_lr
-                for group, base_lr in zip(self.optimizer.param_groups, self.base_lrs):
-                    group['lr'] = base_lr
                 self.finished = True
-            with _enable_get_lr_call(self.after_scheduler):
-                return self.after_scheduler.get_lr()
+            return self.after_scheduler.get_lr()
 
-        return [(self.last_epoch + 1) / (self.warmup_epochs + 1) * lr for lr in self.base_lrs]
+        return [(self.last_epoch + 1) / self.warmup_epochs * lr for lr in self.base_lrs]
 
     def step(self, epoch=None):
         if self.finished:
@@ -28,18 +28,3 @@ class LinearWarmupLR(_LRScheduler):
         else:
             return [(self.total_steps - self.last_epoch) / (self.total_steps - self.warmup_steps) * lr for lr in
                     self.base_lrs]
-
-
-@LR_SCHEDULERS.register_module
-class LinearWarmupDecay(_LRScheduler):
-    def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, last_epoch: int = -1, **kwargs):
-        self.warmup_steps = int(warmup_steps)
-        self.total_steps = total_steps
-        super().__init__(optimizer, last_epoch=last_epoch)
-
-    def get_lr(self):
-        if self.last_epoch < self.warmup_steps:
-            return [(self.last_epoch + 1) / self.warmup_steps * lr for lr in self.base_lrs]
-        else:
-            return [(self.total_steps - self.last_epoch - 1) / (self.total_steps - self.warmup_steps) * lr for lr in
-                    self.base_lrs]
@@ -27,12 +27,7 @@ class MultiStepLR(_MultiStepLR):
     :type last_epoch: int, optional
     """
 
-    def __init__(self, optimizer, total_steps: int, milestones: List[int] = None, gamma: float = 0.1,
-                 num_steps_per_epoch: int = -1, last_epoch: int = -1, **kwargs):
-        if num_steps_per_epoch <= 0:
-            raise ValueError(
-                f'num_steps_per_epoch must > 0, got {num_steps_per_epoch}')
-        milestones = [v * num_steps_per_epoch for v in milestones]
+    def __init__(self, optimizer, total_steps: int, milestones: List[int] = None, gamma: float = 0.1, last_epoch: int = -1, **kwargs):
         super().__init__(optimizer, milestones, gamma=gamma, last_epoch=last_epoch)
 
 
@@ -57,14 +52,11 @@ class MultiStepWarmupLR(WarmupScheduler):
     """
 
     def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, milestones: List[int] = None,
-                 gamma: float = 0.1, num_steps_per_epoch: int = -1, last_epoch: int = -1, **kwargs):
+                 gamma: float = 0.1, last_epoch: int = -1, **kwargs):
         if len(milestones) == 0:
             raise ValueError('milestones cannot be empty')
-        if num_steps_per_epoch <= 0:
-            raise ValueError(
-                f'num_steps_per_epoch must > 0, got {num_steps_per_epoch}')
-        milestones = [v * num_steps_per_epoch - warmup_steps for v in milestones if v *
-                      num_steps_per_epoch >= warmup_steps]
+        milestones = [
+            v - warmup_steps for v in milestones if v >= warmup_steps]
         base_scheduler = _MultiStepLR(optimizer, milestones=milestones,
                                       gamma=gamma)
         super().__init__(optimizer, warmup_steps, base_scheduler, last_epoch=last_epoch)
@@ -1,7 +1,7 @@
 from torch.optim.lr_scheduler import LambdaLR as _LambdaLR
 from torch.optim.lr_scheduler import MultiplicativeLR as _MultiplicativeLR
 from torch.optim.lr_scheduler import StepLR as _StepLR
-from torch.optim.lr_scheduler import _LRScheduler
+from torch.optim.lr_scheduler import ExponentialLR as _ExponentialLR
 
 from colossalai.registry import LR_SCHEDULERS
 
@@ -25,11 +25,8 @@ class LambdaLR(_LambdaLR):
     :type last_epoch: int, optional
     """
 
-    def __init__(self, optimizer, total_steps, lr_lambda=None, num_steps_per_epoch: int = -1,
-                 last_epoch: int = -1) -> None:
-        def func(step): return lr_lambda(step // num_steps_per_epoch)
-
-        super().__init__(optimizer, func, last_epoch=last_epoch)
+    def __init__(self, optimizer, total_steps, lr_lambda=None, last_epoch: int = -1) -> None:
+        super().__init__(optimizer, lr_lambda, last_epoch=last_epoch)
 
 
 @LR_SCHEDULERS.register_module
@@ -51,11 +48,8 @@ class MultiplicativeLR(_MultiplicativeLR):
     :type last_epoch: int, optional
     """
 
-    def __init__(self, optimizer, total_steps, lr_lambda=None, num_steps_per_epoch: int = -1,
-                 last_epoch: int = -1) -> None:
-        def func(step): return lr_lambda(step // num_steps_per_epoch)
-
-        super().__init__(optimizer, func, last_epoch=last_epoch)
+    def __init__(self, optimizer, total_steps, lr_lambda=None, last_epoch: int = -1) -> None:
+        super().__init__(optimizer, lr_lambda, last_epoch=last_epoch)
 
 
 @LR_SCHEDULERS.register_module
@@ -79,14 +73,13 @@ class StepLR(_StepLR):
     :type last_epoch: int, optional
     """
 
-    def __init__(self, optimizer, total_steps, step_size: int = 1, gamma: float = 0.1, num_steps_per_epoch: int = -1,
-                 last_epoch: int = -1) -> None:
-        super().__init__(optimizer, step_size * num_steps_per_epoch,
+    def __init__(self, optimizer, total_steps, step_size: int = 1, gamma: float = 0.1, last_epoch: int = -1) -> None:
+        super().__init__(optimizer, step_size,
                          gamma=gamma, last_epoch=last_epoch)
 
 
 @LR_SCHEDULERS.register_module
-class ExponentialLR(_LRScheduler):
+class ExponentialLR(_ExponentialLR):
     """Decays the learning rate of each parameter group by gamma every epoch.
     When last_epoch=-1, sets initial lr as lr
 
@@ -102,21 +95,6 @@ class ExponentialLR(_LRScheduler):
     :type last_epoch: int, optional
     """
 
-    def __init__(self, optimizer, total_steps, gamma: float = 1.0, num_steps_per_epoch: int = -1,
+    def __init__(self, optimizer, total_steps, gamma: float = 1.0,
                  last_epoch: int = -1) -> None:
-        self.gamma = gamma
-        self.num_steps_per_epoch = num_steps_per_epoch
-        super().__init__(optimizer, last_epoch=last_epoch)
-
-    def get_lr(self):
-        if self.last_epoch == 0:
-            return self.base_lrs
-        elif (self.last_epoch + 1) % self.num_steps_per_epoch == 0:
-            return [group['lr'] * self.gamma
-                    for group in self.optimizer.param_groups]
-        return [group['lr']
-                for group in self.optimizer.param_groups]
-
-    def _get_closed_form_lr(self):
-        return [base_lr * self.gamma ** (self.last_epoch // self.num_steps_per_epoch)
-                for base_lr in self.base_lrs]
+        super().__init__(optimizer, gamma, last_epoch=last_epoch)
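After these changes the wrapped schedulers take plain step counts (`total_steps`, `warmup_steps`) and no longer need `num_steps_per_epoch`. A hedged sketch of stepping the warmup-wrapped cosine scheduler once per iteration; the constructor signature comes from the diff, the import path is an assumption:

```python
import torch
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR  # assumed import path

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = CosineAnnealingWarmupLR(optimizer, total_steps=1000, warmup_steps=100, eta_min=0)

for step in range(1000):
    optimizer.step()    # parameter update (loss/backward omitted for brevity)
    scheduler.step()    # linear warmup for 100 steps, cosine annealing afterwards
```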
@@ -106,7 +106,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
         tensor_parallel_norm = _calc_lp(tensor_parallel_grads, norm_type)
         no_tensor_parallel_grads = _calc_lp(
             no_tensor_parallel_grads, norm_type)
-        if gpc.is_initialized(ParallelMode.TENSOR):
+        if gpc.is_initialized(ParallelMode.TENSOR) and len(tensor_parallel_grads) > 0:
            # Sum across all model-parallel GPUs.
            torch.distributed.all_reduce(tensor_parallel_norm,
                                         op=torch.distributed.ReduceOp.SUM,
 import torch
 import torch.distributed as dist
 
 try:
     from deepspeed.git_version_info import version
     from deepspeed.moe.utils import is_moe_param
@@ -13,7 +14,7 @@ try:
     from deepspeed.ops.op_builder import UtilsBuilder
     from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION_GRADIENTS
 except ImportError:
-    print('DeepSpeed is required if you want to use ZeRO.')
+    pass
 from packaging import version as pkg_version
 from torch._six import inf
 from torch.distributed.distributed_c10d import _get_global_rank
@@ -21,7 +21,7 @@ try:
     from deepspeed.runtime.zero.partition_parameters import *
     from deepspeed.runtime.zero.partition_parameters import _init_external_params
 except ImportError:
-    print('DeepSpeed is required if you want to use ZeRO.')
+    pass
 
 from torch._six import inf
 from torch.distributed.distributed_c10d import _get_global_rank
@@ -20,3 +20,4 @@ TRANSFORMS = Registry('transforms', third_party_library=[transforms])
 PIPE_ALLOC_POLICY = Registry('pipeline_allocation_policy')
 SAMPLERS = Registry('samplers')
 LR_SCHEDULERS = Registry('lr_schedulers')
+SCHEDULE = Registry('schedules')
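With the new SCHEDULE registry, a user-defined schedule can presumably be registered and then selected from the config through `engine.schedule.type`, which `initialize()` resolves via `build_schedule()`. The class below and its config stanza are illustrative only; the import of `SCHEDULE` from `colossalai.registry` is assumed from the registry module shown above:

```python
from colossalai.registry import SCHEDULE          # assumed import
from colossalai.engine import NoPipelineSchedule


@SCHEDULE.register_module
class MyDebugSchedule(NoPipelineSchedule):
    """Hypothetical custom schedule that simply reuses the non-pipeline logic."""
    pass


# in the config file:
# engine = dict(schedule=dict(type='MyDebugSchedule'))
```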
@@ -1,5 +1,5 @@
 from ._trainer import Trainer
 from .hooks import *
-from .metric import Loss, Accuracy2D, Accuracy3D, Accuracy2p5D
+from .metric import Loss, Accuracy2D, Accuracy3D, Accuracy2p5D, LearningRate
 
-__all__ = ['Trainer', 'Loss', 'Accuracy3D', 'Accuracy2D', 'Accuracy2p5D']
+__all__ = ['Trainer', 'Loss', 'Accuracy3D', 'Accuracy2D', 'Accuracy2p5D', 'LearningRate']
@@ -1,7 +1,6 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-from typing import Optional
 from typing import Union, List
 
 import torch
@@ -10,12 +9,11 @@ from torch.utils.data import DataLoader
 from tqdm import tqdm
 
 from colossalai.builder import build_hooks
-from colossalai.checkpointing import save_checkpoint, load_checkpoint, get_checkpoint_path
-from colossalai.context import Config
 from colossalai.engine import Engine
 from colossalai.logging import get_global_dist_logger
-from colossalai.utils import get_global_multitimer, is_dp_rank_0, is_tp_rank_0, is_no_pp_or_last_stage
 from colossalai.nn.data import DataParallelSampler
+from colossalai.utils import MultiTimer
+from colossalai.utils import is_dp_rank_0, is_tp_rank_0, is_no_pp_or_last_stage
 
 
 class Trainer:
@@ -30,43 +28,31 @@ class Trainer:
     :type hoooks_cfg: Config, optional
     :type verbose: bool, optional
     """
 
     def __init__(self,
                  engine: Engine,
-                 hooks_cfg: Optional[Config] = None,
-                 verbose: bool = False):
+                 verbose: bool = False,
+                 timer: MultiTimer = None):
         # training-ralated params
         self._engine = engine
-        self._max_epochs = float('inf')
-        self._max_steps = float('inf')
+        self._max_epochs = 0
         self._cur_epoch = 0
+        self._max_steps = 0
         self._cur_step = 0
+        self._steps_per_epoch = 0
 
-        # data-related params
-        self._train_dataloader = None
-        self._test_dataloader = None
 
         # misc params
-        self._display_progress = False
         self._logger = get_global_dist_logger()
         self._verbose = verbose
 
         # hooks can store states in this dict, and could be consumed by other hooks
-        self.states = {}
+        self.states = dict()
 
         # build hooks
         self.hooks = list()
-        if hooks_cfg is not None:
-            for cfg in hooks_cfg:
-                hook = build_hooks(cfg, self)
-                self.hooks.append(hook)
-        self.hooks.sort(key=lambda hook: hook.priority)
-        if self._verbose:
-            for hook in self.hooks:
-                self._logger.info(
-                    f'build {hook.__class__.__name__} for train, priority = {hook.priority}', ranks=[0])
 
-        # timer
-        self._timer = get_global_multitimer()
+        # multi-timer for time benchmarking
+        self._timer = timer
 
     @property
     def cur_epoch(self):
@@ -74,13 +60,65 @@ class Trainer:
         """
         return self._cur_epoch
 
+    @cur_epoch.setter
+    def cur_epoch(self, epoch: int):
+        """Set how many epochs have been processed.
+        """
+        # allow setter for training resumption
+        self._cur_epoch = epoch
+
     @property
     def cur_step(self):
         """Returns how many iteration steps have been processed.
         """
         return self._cur_step
 
-    def call_hooks(self, func, output=None):
+    @property
+    def max_epochs(self):
+        return self._max_epochs
+
+    @property
+    def max_steps(self):
+        return self._max_steps
+
+    @property
+    def steps_per_epoch(self):
+        return self._steps_per_epoch
+
+    @property
+    def engine(self):
+        return self._engine
+
+    @engine.setter
+    def engine(self, engine_: Engine):
+        self._engine = engine_
+
+    def _set_current_step(self, epoch: int):
+        """Sets current step number.
+
+        :param epoch: Step number to be set
+        :type epoch: int
+        """
+        self._cur_step = epoch * self._steps_per_epoch
+
+    def _call_timer(self, action: str, item: str, *args, **kwargs) -> None:
+        """Call timer funciton with a given timer name.
+
+        :param action: Function to be called on timer
+        :type action: str
+        :param item: Name of the timer
+        :type item: str
+        """
+
+        if self._timer is not None:
+            getattr(self._timer, action)(item, *args, **kwargs)
+
+    def _reset_states(self) -> None:
+        """Clear trainer states
+        """
+        self.states = dict()
+
+    def _call_hooks(self, func, output=None):
         """Calls specific hooks in the current time point.
 
         :param func: A string represents the time point
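The reworked constructor no longer builds hooks or pulls a global timer; the timer is injected and `cur_epoch` is now settable so a resumed run can report the right epoch. A hedged sketch of the new surface, assuming an `engine` object already exists (for example the one returned by `initialize()`):

```python
from colossalai.trainer import Trainer
from colossalai.utils import MultiTimer

trainer = Trainer(engine=engine, verbose=True, timer=MultiTimer())
trainer.cur_epoch = 5   # e.g. resuming from a checkpoint saved after epoch 5
```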
@@ -95,161 +133,186 @@ class Trainer:
         else:
             getattr(hook, func)(*output)
 
-    def exceed_max_step(self):
-        """Checks whether the trainer exceeds the maximum number of runnning iterations.
+    @staticmethod
+    def _should_display_progress(display_progress: bool):
+        """ Only display progress on DP rank 0, TP rank 0 and PP last rank
         """
-        return self._cur_step >= self._max_steps
+        return display_progress and is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage()
 
-    def set_epoch(self, epoch):
-        """Sets current epoch number.
-
-        :param epoch: Epoch number to be set
-        :type epoch: int
-        """
-        self._cur_epoch = epoch
-
-    def _recover_steps(self):
-        step = self.cur_step * self._engine.schedule.num_steps
-        self._cur_step = step
-
-    def _set_display_progress(self, display_progress: bool):
-        self._display_progress = display_progress and is_dp_rank_0(
-        ) and is_tp_rank_0() and is_no_pp_or_last_stage()
-
-    def _train_epoch(self, epoch: int = None):
+    def _train_epoch(self,
+                     train_dataloader: DataLoader,
+                     epoch: int = None,
+                     display_progress: bool = False):
         # set sampler epoch
         if epoch is not None and \
-                hasattr(self._engine.train_dataloader, 'sampler') and \
-                isinstance(self._engine.train_dataloader.sampler, DataParallelSampler):
-            self._engine.train_dataloader.sampler.set_epoch(epoch)
+                hasattr(train_dataloader, 'sampler') and \
+                isinstance(train_dataloader.sampler, DataParallelSampler):
+            train_dataloader.sampler.set_epoch(epoch)
 
+        # set training state
         self._engine.train()
+        data_iter = iter(train_dataloader)
-        progress = range(self._engine.schedule.num_steps)
-        if self._display_progress:
+        progress = range(self._steps_per_epoch)
+        if display_progress:
             if epoch is None:
                 progress = tqdm(progress, desc='[Train]')
             else:
                 progress = tqdm(progress, desc=f'[Epoch {epoch} train]')
 
         # train 1 epoch
-        self.call_hooks('before_train_epoch')
-        self._timer.start('train-epoch')
-        for _ in progress:
-            self._cur_step += 1
-
-            self.call_hooks('before_train_iter')
-            self._timer.start('train-step')
-            logits, label, loss = self._engine.step()
-            self._timer.stop('train-step', keep_in_history=True)
-            self.call_hooks('after_train_iter', output=(logits, label, loss))
-
-            if self.exceed_max_step():
-                # stop when max iter is reached
-                break
-        self._timer.stop('train-epoch', keep_in_history=True)
-        self.call_hooks('after_train_epoch')
-        self._timer.reset('train-step')
+        self._call_hooks('before_train_epoch')
+        self._call_timer(action='start', item='train-epoch')
+        for i in progress:
+            self._call_hooks('before_train_iter')
+            self._call_timer(action='start', item='train-step')
+
+            if i == self._steps_per_epoch - 1:
+                is_last_iteration = True
+            else:
+                is_last_iteration = False
+
+            # run 1 training step
+            logits, label, loss = self._engine.step(data_iter, is_last_iteration)
+            self._call_timer(action='stop', item='train-step', keep_in_history=True)
+            self._call_hooks('after_train_iter', output=(logits, label, loss))
+
+            self._cur_step += 1
+
+            # stop when max iter is reached
+            if self._exceed_max_step():
+                break
+
+        self._call_timer(action='stop', item='train-epoch', keep_in_history=True)
+        self._call_hooks('after_train_epoch')
+        self._call_timer(action='reset', item='train-step')
 
     def _eval(self,
+              test_dataloader: DataLoader,
               epoch: int = None,
-              return_loss: bool = True):
+              display_progress: bool = False):
         # switch engine status
         self._engine.eval()
 
-        self.call_hooks('before_test')
+        data_iter = iter(test_dataloader)
+        num_steps = len(test_dataloader)
+
+        self._call_hooks('before_test')
         with torch.no_grad():
             # prepare progress bar
|
||||||
return_loss: bool = True):
|
display_progress: bool = False):
|
||||||
# switch engine status
|
# switch engine status
|
||||||
self._engine.eval()
|
self._engine.eval()
|
||||||
|
|
||||||
self.call_hooks('before_test')
|
data_iter = iter(test_dataloader)
|
||||||
|
num_steps = len(test_dataloader)
|
||||||
|
|
||||||
|
self._call_hooks('before_test')
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# prepare progress bar
|
# prepare progress bar
|
||||||
progress = range(self._engine.schedule.num_steps)
|
progress = range(num_steps)
|
||||||
if self._display_progress:
|
if display_progress:
|
||||||
desc = 'Evaluation'
|
desc = 'Evaluation'
|
||||||
if epoch is not None:
|
if epoch is not None:
|
||||||
desc = '[Epoch %d val]' % epoch
|
desc = '[Epoch %d val]' % epoch
|
||||||
progress = tqdm(progress, desc=desc)
|
progress = tqdm(progress, desc=desc)
|
||||||
|
|
||||||
self.call_hooks('before_test_epoch')
|
self._call_hooks('before_test_epoch')
|
||||||
self._timer.start('test-epoch')
|
self._call_timer(action='start', item='test-epoch')
|
||||||
for _ in progress:
|
for _ in progress:
|
||||||
self.call_hooks('before_test_iter')
|
self._call_hooks('before_test_iter')
|
||||||
self._timer.start('test-step')
|
self._call_timer(action='start', item='test-step')
|
||||||
logits, label, loss = self._engine.step(
|
logits, label, loss = self._engine.step(data_iter, return_loss=True)
|
||||||
return_loss=return_loss)
|
self._call_timer(action='stop', item='test-step', keep_in_history=True)
|
||||||
self._timer.stop('test-step', keep_in_history=True)
|
self._call_hooks('after_test_iter',
|
||||||
self.call_hooks('after_test_iter',
|
|
||||||
output=(logits, label, loss))
|
output=(logits, label, loss))
|
||||||
self._timer.stop('test-epoch', keep_in_history=True)
|
self._call_timer(action='stop', item='test-epoch', keep_in_history=True)
|
||||||
self.call_hooks('after_test_epoch')
|
self._call_hooks('after_test_epoch')
|
||||||
self.call_hooks('after_test')
|
self._call_hooks('after_test')
|
||||||
self._timer.reset('test-step')
|
self._call_timer(action='reset', item='test-step')
|
||||||
self._timer.reset('test-epoch')
|
self._call_timer(action='reset', item='test-epoch')
|
||||||
|
|
||||||
|
def _exceed_max_step(self):
|
||||||
|
return self._max_steps is not None and self._cur_step > self._max_steps
|
||||||
|
|
||||||
def fit(self,
|
def fit(self,
|
||||||
train_dataloader: DataLoader,
|
train_dataloader: DataLoader,
|
||||||
test_dataloader: DataLoader = None,
|
epochs: int,
|
||||||
max_epochs: int = None,
|
|
||||||
max_steps: int = None,
|
max_steps: int = None,
|
||||||
|
test_dataloader: DataLoader = None,
|
||||||
test_interval: int = 1,
|
test_interval: int = 1,
|
||||||
display_progress: bool = False):
|
hooks_cfg: dict = None,
|
||||||
|
display_progress: bool = False,
|
||||||
|
):
|
||||||
"""Trains the model to fit training data.
|
"""Trains the model to fit training data.
|
||||||
|
|
||||||
:param train_dataloader: DataLoader in training
|
:param train_dataloader: DataLoader in training
|
||||||
:param test_dataloader: DataLoader in testing
|
:param epochs: Maximum number of epoches
|
||||||
:param max_epochs: Maximum number of epoches
|
|
||||||
:param max_steps: Maximum number of running iterations
|
:param max_steps: Maximum number of running iterations
|
||||||
|
:param test_dataloader: DataLoader in testing
|
||||||
:param test_interval: Interval of testing
|
:param test_interval: Interval of testing
|
||||||
|
:param hooks_cfg: A list of hook configuration
|
||||||
:param display_progress: If True, the training progress will be printed
|
:param display_progress: If True, the training progress will be printed
|
||||||
:type train_dataloader: DataLoader
|
:type train_dataloader: DataLoader
|
||||||
:type test_dataloader: DataLoader
|
:type epochs: int
|
||||||
:type max_epochs: int
|
|
||||||
:type max_steps: int
|
:type max_steps: int
|
||||||
|
:type test_dataloader: DataLoader
|
||||||
:type test_interval: int
|
:type test_interval: int
|
||||||
|
:type hooks_cfg: dict
|
||||||
:type display_progress: bool
|
:type display_progress: bool
|
||||||
|
:type gradient_accumulation: int
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# prepare dataloaders
|
# set epochs and steps, consider gradient accumulation
|
||||||
self._train_dataloader = train_dataloader
|
self._steps_per_epoch = len(train_dataloader) // self._engine.gradient_accumulation
|
||||||
self._engine.set_dataloader(self._train_dataloader, train=True)
|
self._max_steps = max_steps
|
||||||
self._engine.train()
|
self._max_epochs = epochs
|
||||||
|
|
||||||
|
# check if testing is required
|
||||||
should_test = False
|
should_test = False
|
||||||
if test_dataloader is not None:
|
if test_dataloader is not None:
|
||||||
self._test_dataloader = test_dataloader
|
|
||||||
self._engine.set_dataloader(self._test_dataloader, train=False)
|
|
||||||
should_test = True
|
should_test = True
|
||||||
|
|
||||||
# decide the
|
display_progress = self._should_display_progress(display_progress)
|
||||||
if max_epochs is not None:
|
|
||||||
self._max_epochs = max_epochs
|
# reset hooks
|
||||||
if max_steps is not None:
|
self._reset_states()
|
||||||
self._max_steps = max_steps
|
self.hooks = list()
|
||||||
self._set_display_progress(display_progress)
|
|
||||||
|
# build hooks
|
||||||
|
if hooks_cfg is not None:
|
||||||
|
for cfg in hooks_cfg:
|
||||||
|
hook = build_hooks(cfg, self)
|
||||||
|
self.hooks.append(hook)
|
||||||
|
self.hooks.sort(key=lambda hook: hook.priority)
|
||||||
|
if self._verbose:
|
||||||
|
for hook in self.hooks:
|
||||||
|
self._logger.info(
|
||||||
|
f'build {hook.__class__.__name__} for training, priority = {hook.priority}', ranks=[0])
|
||||||
|
self._logger.info("Lower value means higher priority for calling hook function")
|
||||||
|
|
||||||
# start train
|
# start train
|
||||||
self.call_hooks('before_train')
|
self._engine.train()
|
||||||
|
self._call_hooks('before_train')
|
||||||
|
|
||||||
# recover step value if resuming training
|
# recover step value if resuming training
|
||||||
if self.cur_epoch != 0:
|
|
||||||
self._recover_steps()
|
|
||||||
|
|
||||||
last_epoch = self._cur_epoch
|
last_epoch = self._cur_epoch
|
||||||
|
if self.cur_epoch != 0:
|
||||||
|
self._set_current_step(last_epoch)
|
||||||
|
|
||||||
for epoch in range(last_epoch, self._max_epochs):
|
for epoch in range(last_epoch, epochs):
|
||||||
self._cur_epoch += 1
|
|
||||||
|
|
||||||
# train for one epoch
|
# train for one epoch
|
||||||
self._train_epoch(epoch)
|
self._train_epoch(
|
||||||
|
train_dataloader=train_dataloader,
|
||||||
|
epoch=epoch,
|
||||||
|
display_progress=display_progress
|
||||||
|
)
|
||||||
|
|
||||||
# start eval
|
# start eval
|
||||||
if should_test and epoch % test_interval == 0:
|
if should_test and epoch % test_interval == 0:
|
||||||
self._eval(epoch, return_loss=True)
|
self._eval(test_dataloader=test_dataloader,
|
||||||
|
display_progress=display_progress,
|
||||||
|
epoch=epoch,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._cur_epoch += 1
|
||||||
|
|
||||||
# check for termination
|
# check for termination
|
||||||
if self.exceed_max_step():
|
if self._exceed_max_step():
|
||||||
self._logger.info(
|
self._logger.info(
|
||||||
f"Max number of steps {self._max_steps} has been reached, training is stopped automatically")
|
f"Max number of steps {max_steps} has been reached, training is stopped automatically")
|
||||||
break
|
break
|
||||||
self.call_hooks('after_train')
|
self._call_hooks('after_train')
|
||||||
self._timer.reset('train-epoch')
|
self._call_timer('reset', 'train-epoch')
|
||||||
|
|
||||||
def evaluate(self,
|
def evaluate(self,
|
||||||
test_dataloader: DataLoader,
|
test_dataloader: DataLoader,
|
||||||
|
@ -261,15 +324,13 @@ class Trainer:
|
||||||
:type test_dataloader: DataLoader
|
:type test_dataloader: DataLoader
|
||||||
:type display_progress: bool, optional
|
:type display_progress: bool, optional
|
||||||
"""
|
"""
|
||||||
# set dataloader
|
# set display
|
||||||
self._test_dataloader = test_dataloader
|
display_progress = self._should_display_progress(display_progress)
|
||||||
self._engine.set_dataloader(self._test_dataloader, train=True)
|
|
||||||
|
|
||||||
# set
|
|
||||||
self._set_display_progress(display_progress)
|
|
||||||
|
|
||||||
# eval
|
# eval
|
||||||
self._eval(return_loss=True)
|
self._eval(test_dataloader=test_dataloader,
|
||||||
|
display_progress=display_progress,
|
||||||
|
)
|
||||||
|
|
||||||
def predict(self, data: Union[Tensor, List[Tensor]]):
|
def predict(self, data: Union[Tensor, List[Tensor]]):
|
||||||
"""Uses trained model to make a prediction for a tensor or a tensor list.
|
"""Uses trained model to make a prediction for a tensor or a tensor list.
|
||||||
|
@ -289,45 +350,6 @@ class Trainer:
|
||||||
# prepare a list of (data, label) to make it iterable
|
# prepare a list of (data, label) to make it iterable
|
||||||
# for compatibility with schedule
|
# for compatibility with schedule
|
||||||
simple_dataloader = [(data, None)]
|
simple_dataloader = [(data, None)]
|
||||||
self._engine.set_dataloader(simple_dataloader)
|
data_iter = iter(simple_dataloader)
|
||||||
output, _, _ = self._engine.step(return_loss=False)
|
output, _, _ = self._engine.step(data_iter, return_loss=False)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def save(self, path: str, suffix: str = ''):
|
|
||||||
"""Saves the model to a file.
|
|
||||||
|
|
||||||
:param path: Relative path of the file
|
|
||||||
:param suffix: Suffix of the file
|
|
||||||
:type path: str
|
|
||||||
:type suffix: str, optional
|
|
||||||
"""
|
|
||||||
save_path = get_checkpoint_path(path,
|
|
||||||
self._cur_epoch,
|
|
||||||
suffix=suffix)
|
|
||||||
save_checkpoint(save_path, self._cur_epoch, self._engine.get_model(),
|
|
||||||
self._engine.get_optimizer(),
|
|
||||||
self._engine.get_lr_scheduler())
|
|
||||||
|
|
||||||
def load(self,
|
|
||||||
path: str,
|
|
||||||
finetune: bool = False,
|
|
||||||
strict: bool = False):
|
|
||||||
"""Loads parameters to the model from a file.
|
|
||||||
|
|
||||||
:param path: Relative path of the file
|
|
||||||
:param finetune: Whether allows to load a part of the model
|
|
||||||
:param strict: Whether loads a model that has the same shape of parameters
|
|
||||||
:type path: str
|
|
||||||
:type finetune: bool, optional
|
|
||||||
:type strict: bool, optional
|
|
||||||
"""
|
|
||||||
last_epoch, _ = load_checkpoint(path,
|
|
||||||
self._engine.get_model(),
|
|
||||||
self._engine.get_optimizer(),
|
|
||||||
self._engine.get_lr_scheduler(),
|
|
||||||
finetune=finetune,
|
|
||||||
strict=strict)
|
|
||||||
if finetune:
|
|
||||||
self.set_epoch(0)
|
|
||||||
else:
|
|
||||||
self.set_epoch(last_epoch)
|
|
||||||
|
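A note on the new `cur_epoch` setter and step bookkeeping above: training is now resumed by rewinding the trainer's epoch counter instead of calling the removed `_recover_steps`. A minimal sketch, assuming a built `trainer` and a `last_epoch` value read back from a checkpoint (both names are illustrative, this is not part of the commit):

```python
# Sketch only: rewind the trainer before fitting. fit() then recomputes
# cur_step as last_epoch * steps_per_epoch via _set_current_step() and
# starts its epoch loop at last_epoch. The updated LoadCheckpointHook
# does exactly this in before_train().
last_epoch = 10                 # e.g. restored from a saved checkpoint
trainer.cur_epoch = last_epoch  # allowed now that cur_epoch has a setter
```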
@@ -2,10 +2,12 @@ from ._base_hook import BaseHook
 from ._checkpoint_hook import SaveCheckpointHook, LoadCheckpointHook
 from ._metric_hook import LossHook, Accuracy2DHook, AccuracyHook, MetricHook
 from ._log_hook import LogMetricByEpochHook, TensorboardHook, LogTimingByEpochHook, LogMemoryByEpochHook
+from ._lr_scheduler_hook import LRSchedulerHook

 __all__ = [
     'BaseHook', 'MetricHook',
     'LoadCheckpointHook', 'SaveCheckpointHook',
     'LossHook', 'AccuracyHook', 'Accuracy2DHook',
     'LogMetricByEpochHook', 'TensorboardHook', 'LogTimingByEpochHook', 'LogMemoryByEpochHook',
+    'LRSchedulerHook'
 ]
@@ -3,13 +3,13 @@

 import os.path as osp

-import torch.distributed as dist
-
-from colossalai.checkpointing import get_latest_checkpoint_path, get_checkpoint_path
 from colossalai.registry import HOOKS
-from colossalai.trainer.hooks import BaseHook
 from colossalai.trainer import Trainer
+from colossalai.trainer.hooks import BaseHook
 from colossalai.utils import is_dp_rank_0
+from colossalai.utils.checkpointing import get_latest_checkpoint_path, get_checkpoint_path
+from colossalai.utils.checkpointing import save_checkpoint, load_checkpoint
+from ._lr_scheduler_hook import LRSchedulerHook


 @HOOKS.register_module
@@ -33,7 +33,7 @@ class SaveCheckpointHook(BaseHook):
                  interval: int = 1,
                  checkpoint_dir: str = None,
                  suffix: str = '',
-                 priority: int = 0):
+                 priority: int = 10):
         super().__init__(trainer=trainer, priority=priority)
         assert isinstance(trainer, Trainer), \
             f'SaveCheckpointHook expects a Trainer, got {type(trainer)}'
@@ -41,6 +41,16 @@ class SaveCheckpointHook(BaseHook):
         self.checkpoint_dir = checkpoint_dir
         self.suffix = suffix

+        # get lr scheduler from the LRSchedulerHook before train
+        self._lr_scheduler = None
+
+    def before_train(self):
+        # check if lr scheduler is present in LRSchedulerHook
+        for hook in self.trainer.hooks:
+            if isinstance(hook, LRSchedulerHook):
+                self._lr_scheduler = hook.lr_scheduler
+                break
+
     def after_train_epoch(self):
         """Saves the model after a training epoch.
         """
@@ -48,14 +58,18 @@ class SaveCheckpointHook(BaseHook):
         if self.trainer.cur_epoch % self.interval == 0:
             # only gpus with data parallel rank equals to 0 write to the disk
             if is_dp_rank_0():
-                self.trainer.save(path=self.checkpoint_dir, suffix=self.suffix)
+                save_path = get_checkpoint_path(self.checkpoint_dir,
+                                                self.trainer.cur_epoch,
+                                                suffix=self.suffix)
+
+                save_checkpoint(save_path,
+                                self.trainer.cur_epoch,
+                                self.trainer.engine.model,
+                                self.trainer.engine.optimizer,
+                                self._lr_scheduler)
             self.logger.info(
                 f'checkpoint for epoch {self.trainer.cur_epoch} is saved to {self.checkpoint_dir}')
-
-            # wait until everyone is done
-            if dist.is_initialized():
-                dist.barrier()


 @HOOKS.register_module
 class LoadCheckpointHook(BaseHook):
@@ -81,30 +95,46 @@ class LoadCheckpointHook(BaseHook):
                  epoch: int = -1,
                  finetune: bool = False,
                  strict: bool = False,
-                 priority: int = 10) -> None:
+                 suffix: str = '',
+                 priority: int = 0) -> None:
+        super().__init__(trainer=trainer, priority=priority)
         assert isinstance(trainer, Trainer), \
             f'LoadLatestCheckpointHook excepts a Trainer, got {type(trainer)}'
         self.epoch = epoch
         self.checkpoint_dir = checkpoint_dir
         self.finetune = finetune
+        self.suffix = suffix
         self.strict = strict
-        super().__init__(trainer=trainer, priority=priority)

     def before_train(self):
         """Loads parameters to the model before training.
         """
+        # check if lr scheduler is present in LRSchedulerHook
+        lr_scheduler = None
+        for hook in self.trainer.hooks:
+            if isinstance(hook, LRSchedulerHook):
+                lr_scheduler = hook.lr_scheduler
+                break
+
+        # use latest checkpoint if epoch = -1
         if self.epoch == -1:
-            path = get_latest_checkpoint_path(self.checkpoint_dir)
+            path = get_latest_checkpoint_path(self.checkpoint_dir, suffix=self.suffix)
         else:
-            path = get_checkpoint_path(self.checkpoint_dir, epoch=self.epoch)
+            path = get_checkpoint_path(self.checkpoint_dir, epoch=self.epoch, suffix=self.suffix)

         if osp.exists(path):
-            self.trainer.load(
-                path, finetune=self.finetune, strict=self.strict)
+            last_epoch, _ = load_checkpoint(path,
+                                            self.trainer.engine.model,
+                                            self.trainer.engine.optimizer,
+                                            lr_scheduler,
+                                            finetune=self.finetune,
+                                            strict=self.strict)
+            if self.finetune:
+                self.trainer.cur_epoch = 0
+            else:
+                self.trainer.cur_epoch = last_epoch
             self.logger.info(
                 f'loaded checkpoint from {path}')
         else:
             raise FileNotFoundError(f'checkpoint is not found at {path}')
-
-        # Some utilities want to load a checkpoint without distributed being initialized
-        if dist.is_initialized():
-            dist.barrier()
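Since the checkpoint hooks above now take the LR scheduler from `LRSchedulerHook` rather than from the engine, checkpointing is driven entirely by the hook list. A hedged illustration of what such entries could look like in a config (the directory and suffix values are made up):

```python
hooks = [
    # Save every 5 epochs; only data-parallel rank 0 writes to disk.
    dict(type='SaveCheckpointHook', interval=5, checkpoint_dir='./ckpt', suffix='_run1'),
    # epoch=-1 means: load the latest checkpoint found under checkpoint_dir before training.
    dict(type='LoadCheckpointHook', epoch=-1, checkpoint_dir='./ckpt', suffix='_run1'),
]
```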
@@ -5,7 +5,7 @@ import os
 import os.path as osp

 import torch
-from tensorboardX import SummaryWriter
+from torch.utils.tensorboard import SummaryWriter

 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
@@ -13,7 +13,7 @@ from colossalai.registry import HOOKS
 from colossalai.trainer._trainer import Trainer
 from colossalai.utils import get_global_multitimer, set_global_multitimer_status, report_memory_usage, is_dp_rank_0, \
     is_tp_rank_0, is_no_pp_or_last_stage
-from ._metric_hook import MetricHook
+from ._base_hook import BaseHook


 def _format_number(val):
@@ -24,7 +24,7 @@ def _format_number(val):
     return val


-class EpochIntervalHook(MetricHook):
+class EpochIntervalHook(BaseHook):
     def __init__(self, trainer: Trainer, interval: int = 1, priority: int = 1):
         super().__init__(trainer, priority)
         self._interval = interval
@@ -45,7 +45,7 @@ class LogMetricByEpochHook(EpochIntervalHook):
     :type priority: int, optional
     """

-    def __init__(self, trainer: Trainer, interval: int = 1, priority: int = 1) -> None:
+    def __init__(self, trainer: Trainer, interval: int = 1, priority: int = 10) -> None:
         super().__init__(trainer=trainer, interval=interval, priority=priority)
         self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage()

@@ -74,7 +74,7 @@ class LogMetricByEpochHook(EpochIntervalHook):


 @HOOKS.register_module
-class TensorboardHook(MetricHook):
+class TensorboardHook(BaseHook):
     """Specialized Hook to record the metric to Tensorboard.

     :param trainer: Trainer attached with current hook
@@ -85,59 +85,71 @@ class TensorboardHook(MetricHook):
     :type priority: int, optional
     """

-    def __init__(self, trainer: Trainer, log_dir: str, priority: int = 1) -> None:
+    def __init__(self,
+                 trainer: Trainer,
+                 log_dir: str,
+                 dp_rank_0_only: bool = True,
+                 tp_rank_0_only: bool = True,
+                 priority: int = 10,
+                 ) -> None:
         super().__init__(trainer=trainer, priority=priority)
-        self._is_rank_to_log = is_no_pp_or_last_stage()

-        if self._is_rank_to_log:
+        # create log dir
+        if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0:
+            os.makedirs(log_dir, exist_ok=True)
+
+        # determine the ranks to generate tensorboard logs
+        self._is_valid_rank_to_log = is_no_pp_or_last_stage()
+
+        if dp_rank_0_only:
+            self._is_valid_rank_to_log = self._is_valid_rank_to_log and is_dp_rank_0()
+
+        if tp_rank_0_only:
+            self._is_valid_rank_to_log = self._is_valid_rank_to_log and is_tp_rank_0()
+
+        if self._is_valid_rank_to_log:
             # create workspace on only one rank
             if gpc.is_initialized(ParallelMode.GLOBAL):
                 rank = gpc.get_global_rank()
             else:
                 rank = 0

-            log_dir = osp.join(log_dir, f'rank_{rank}')
-
             # create workspace
-            if not osp.exists(log_dir):
-                os.makedirs(log_dir)
+            log_dir = osp.join(log_dir, f'rank_{rank}')
+            os.makedirs(log_dir, exist_ok=True)

             self.writer = SummaryWriter(
                 log_dir=log_dir, filename_suffix=f'_rank_{rank}')

-    def after_train_iter(self, *args):
-        for metric_name, metric_calculator in self.trainer.states['metrics']['train'].items():
+    def _log_by_iter(self, mode: str):
+        for metric_name, metric_calculator in self.trainer.states['metrics'][mode].items():
             if metric_calculator.epoch_only:
                 continue
             val = metric_calculator.get_last_step_value()
-            if self._is_rank_to_log:
-                self.writer.add_scalar(
-                    f'{metric_name}/train', val, self.trainer.cur_step)
+            if self._is_valid_rank_to_log:
+                self.writer.add_scalar(f'{metric_name}/{mode}', val,
+                                       self.trainer.cur_step)
+
+    def _log_by_epoch(self, mode: str):
+        for metric_name, metric_calculator in self.trainer.states['metrics'][mode].items():
+            if metric_calculator.epoch_only:
+                val = metric_calculator.get_accumulated_value()
+                if self._is_valid_rank_to_log:
+                    self.writer.add_scalar(f'{metric_name}/{mode}', val,
+                                           self.trainer.cur_step)

     def after_test_iter(self, *args):
-        for metric_name, metric_calculator in self.trainer.states['metrics']['test'].items():
-            if metric_calculator.epoch_only:
-                continue
-            val = metric_calculator.get_last_step_value()
-            if self._is_rank_to_log:
-                self.writer.add_scalar(f'{metric_name}/test', val,
-                                       self.trainer.cur_step)
+        self._log_by_iter(mode='test')

     def after_test_epoch(self):
-        for metric_name, metric_calculator in self.trainer.states['metrics']['test'].items():
-            if metric_calculator.epoch_only:
-                val = metric_calculator.get_accumulated_value()
-                if self._is_rank_to_log:
-                    self.writer.add_scalar(f'{metric_name}/test', val,
-                                           self.trainer.cur_step)
+        self._log_by_epoch(mode='test')
+
+    def after_train_iter(self, *args):
+        self._log_by_iter(mode='train')

     def after_train_epoch(self):
-        for metric_name, metric_calculator in self.trainer.states['metrics']['train'].items():
-            if metric_calculator.epoch_only:
-                val = metric_calculator.get_accumulated_value()
-                if self._is_rank_to_log:
-                    self.writer.add_scalar(f'{metric_name}/train', val,
-                                           self.trainer.cur_step)
+        self._log_by_epoch(mode='train')


 @HOOKS.register_module
@@ -157,7 +169,7 @@ class LogTimingByEpochHook(EpochIntervalHook):
     def __init__(self,
                  trainer: Trainer,
                  interval: int = 1,
-                 priority: int = 1,
+                 priority: int = 10,
                  log_eval: bool = True
                  ) -> None:
         super().__init__(trainer=trainer, interval=interval, priority=priority)
@@ -217,7 +229,7 @@ class LogMemoryByEpochHook(EpochIntervalHook):
     def __init__(self,
                  trainer: Trainer,
                  interval: int = 1,
-                 priority: int = 1,
+                 priority: int = 10,
                  log_eval: bool = True
                  ) -> None:
         super().__init__(trainer=trainer, interval=interval, priority=priority)
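The reworked `TensorboardHook` above adds two switches; when both are disabled, every rank that passes the pipeline-stage check writes its own event files under `<log_dir>/rank_<global rank>`. A small illustration (the log directory is a placeholder):

```python
hooks = [
    # Default behaviour: only DP rank 0 / TP rank 0 of the last pipeline stage logs.
    dict(type='TensorboardHook', log_dir='./tb_logs'),
    # Alternative: let every data-parallel and tensor-parallel rank keep its own run.
    # dict(type='TensorboardHook', log_dir='./tb_logs',
    #      dp_rank_0_only=False, tp_rank_0_only=False),
]
```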
@@ -0,0 +1,58 @@
+from torch import Tensor
+
+from colossalai.builder import build_lr_scheduler
+from colossalai.registry import HOOKS
+from ._metric_hook import MetricHook
+from .._trainer import Trainer
+from ..metric import LearningRate
+
+
+@HOOKS.register_module
+class LRSchedulerHook(MetricHook):
+    """Build LR scheduler
+
+    :param trainer: Trainer attached with current hook
+    :type trainer: Trainer
+    :param lr_scheduler_cfg: The config of LR scheduler
+    :type lr_scheduler_cfg: dict
+    :param by_epoch: If `True`, the LR will be scheduled every epoch. Else, the LR will be scheduled every batch. Defaults to `True`.
+    :type by_epoch: bool
+    :param priority: Priority in the printing, hooks with small priority will be printed in front
+    :type priority: int, optional
+    """
+
+    def __init__(self,
+                 trainer: Trainer,
+                 lr_scheduler_cfg: dict,
+                 by_epoch: bool = True,
+                 store_lr_in_state: bool = True,
+                 priority: int = 1,
+                 ):
+        super().__init__(trainer=trainer, priority=priority)
+        self.by_epoch = by_epoch
+
+        if by_epoch:
+            total_steps = trainer.max_epochs
+        else:
+            total_steps = trainer.max_epochs * trainer.steps_per_epoch
+            if trainer.max_steps is not None:
+                total_steps = min(total_steps, trainer.max_steps)
+
+        lr_scheduler_cfg['total_steps'] = total_steps
+
+        self.lr_scheduler = build_lr_scheduler(
+            lr_scheduler_cfg, trainer.engine.optimizer)
+
+        if store_lr_in_state:
+            self.trainer.states['metrics']['train']['lr'] = LearningRate(epoch_only=by_epoch,
+                                                                         initial_lr=self.lr_scheduler.get_lr()[0])
+
+    def after_train_epoch(self):
+        if self.by_epoch:
+            self.lr_scheduler.step()
+            self.trainer.states['metrics']['train']['lr'].update(self.lr_scheduler.get_lr()[0])
+
+    def after_train_iter(self, output: Tensor, label: Tensor, loss: Tensor):
+        if not self.by_epoch:
+            self.lr_scheduler.step()
+            self.trainer.states['metrics']['train']['lr'].update(self.lr_scheduler.get_lr()[0])
@@ -21,9 +21,12 @@ class MetricHook(BaseHook):
     :type priority: int
     """

-    def __init__(self, trainer: Trainer, priority: int):
+    def __init__(self,
+                 trainer: Trainer,
+                 priority: int,
+                 ):
         super().__init__(trainer, priority)
-        self._is_stage_to_log = is_no_pp_or_last_stage()
+        self._is_stage_to_compute = is_no_pp_or_last_stage()
         self._check_metric_states_initialization()

     def _check_metric_states_initialization(self):
@@ -41,33 +44,34 @@ class LossHook(MetricHook):
     :type priority: int, optional
     """

-    def __init__(self, trainer: Trainer, priority: int = 10):
+    def __init__(self, trainer: Trainer, priority: int = 0):
         super().__init__(trainer, priority)

-        if self._is_stage_to_log:
-            self.metric = Loss(epoch_only=False)
+        if self._is_stage_to_compute:
+            self.train_loss = Loss(epoch_only=False)
+            self.test_loss = Loss(epoch_only=True)

             # register the metric calculator
             self.trainer.states['metrics']['train'][
-                self.metric.__class__.__name__] = self.metric
+                self.train_loss.__class__.__name__] = self.train_loss
             self.trainer.states['metrics']['test'][
-                self.metric.__class__.__name__] = self.metric
+                self.test_loss.__class__.__name__] = self.test_loss

     def before_train_epoch(self):
-        if self._is_stage_to_log:
-            self.metric.reset()
+        if self._is_stage_to_compute:
+            self.train_loss.reset()

     def after_train_iter(self, logits, label, loss):
-        if self._is_stage_to_log:
-            self.metric.update(loss)
+        if self._is_stage_to_compute:
+            self.train_loss.update(loss)

     def before_test_epoch(self):
-        if self._is_stage_to_log:
-            self.metric.reset()
+        if self._is_stage_to_compute:
+            self.test_loss.reset()

     def after_test_iter(self, logits, label, loss):
-        if self._is_stage_to_log:
-            self.metric.update(loss)
+        if self._is_stage_to_compute:
+            self.test_loss.update(loss)


 @HOOKS.register_module
@@ -81,10 +85,10 @@ class Accuracy2DHook(MetricHook):
     :type priority: int, optional
     """

-    def __init__(self, trainer: Trainer, priority: int = 10):
+    def __init__(self, trainer: Trainer, priority: int = 0):
         super().__init__(trainer, priority)

-        if self._is_stage_to_log:
+        if self._is_stage_to_compute:
             self.metric = Accuracy2D(epoch_only=True)

             # register the metric
@@ -92,20 +96,20 @@ class Accuracy2DHook(MetricHook):
                 self.metric.__class__.__name__] = self.metric

     def before_test(self):
-        if self._is_stage_to_log:
+        if self._is_stage_to_compute:
             self.metric.reset()

     def after_test_iter(self, logits, label, *args):
-        if self._is_stage_to_log:
+        if self._is_stage_to_compute:
             self.metric.update(logits, label)


 @HOOKS.register_module
 class Accuracy2p5DHook(MetricHook):
-    def __init__(self, trainer: Trainer, priority: int = 10):
+    def __init__(self, trainer: Trainer, priority: int = 0):
         super().__init__(trainer, priority)

-        if self._is_stage_to_log:
+        if self._is_stage_to_compute:
             self.metric = Accuracy2p5D(epoch_only=True)

             # register the metric
@@ -113,11 +117,11 @@ class Accuracy2p5DHook(MetricHook):
                 self.metric.__class__.__name__] = self.metric

     def before_test(self):
-        if self._is_stage_to_log:
+        if self._is_stage_to_compute:
             self.metric.reset()

     def after_test_iter(self, logits, label, *args):
-        if self._is_stage_to_log:
+        if self._is_stage_to_compute:
             self.metric.update(logits, label)


@@ -138,7 +142,7 @@ class Accuracy3DHook(MetricHook):
                  priority: int = 10):
         super().__init__(trainer, priority)

-        if self._is_stage_to_log:
+        if self._is_stage_to_compute:
             self.metric = Accuracy3D(epoch_only=True,
                                      input_parallel_mode=input_parallel_mode,
                                      weight_parallel_mode=weight_parallel_mode)
@@ -148,11 +152,11 @@ class Accuracy3DHook(MetricHook):
                 self.metric.__class__.__name__] = self.metric

     def before_test(self):
-        if self._is_stage_to_log:
+        if self._is_stage_to_compute:
             self.metric.reset()

     def after_test_iter(self, logits, label, *args):
-        if self._is_stage_to_log:
+        if self._is_stage_to_compute:
             self.metric.update(logits, label)


@@ -166,10 +170,10 @@ class AccuracyHook(MetricHook):
     :type priority: int
     """

-    def __init__(self, trainer: Trainer, priority: int = 10):
+    def __init__(self, trainer: Trainer, priority: int = 0):
         super().__init__(trainer, priority)

-        if self._is_stage_to_log:
+        if self._is_stage_to_compute:
             self.metric = Accuracy(epoch_only=True)

             # register the metric
@@ -177,9 +181,9 @@ class AccuracyHook(MetricHook):
                 self.metric.__class__.__name__] = self.metric

     def before_test(self):
-        if self._is_stage_to_log:
+        if self._is_stage_to_compute:
             self.metric.reset()

     def after_test_iter(self, logits, label, *args):
-        if self._is_stage_to_log:
+        if self._is_stage_to_compute:
             self.metric.update(logits, label)
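With `LossHook` now keeping separate `train_loss` and `test_loss` collectors, both phases are registered in the trainer state under the metric class name. A hedged sketch of how another hook could read them back, assuming the `Loss` metric exposes the usual accumulated-value accessor and has been updated at least once:

```python
# Sketch only: metrics are keyed by class name, so both phases use the key 'Loss'.
train_loss = trainer.states['metrics']['train']['Loss'].get_accumulated_value()
test_loss = trainer.states['metrics']['test']['Loss'].get_accumulated_value()
```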
@@ -126,6 +126,33 @@ class Loss(Metric):
         return a < b


+class LearningRate(Metric):
+    """A metric collector for learning rate.
+
+    :param epoch_only: Whether the metric only read for the full epoch
+    :type epoch_only: bool
+    """
+
+    def __init__(self, epoch_only: bool, initial_lr: float = 0.):
+        super().__init__(epoch_only=epoch_only)
+        self.lr = 0.
+
+    def reset(self) -> None:
+        pass
+
+    def update(self, lr) -> None:
+        self.lr = lr
+
+    def get_last_step_value(self):
+        return self.lr
+
+    def get_accumulated_value(self):
+        return self.lr
+
+    def is_better(a, b) -> bool:
+        pass
+
+
 class Accuracy(Metric):
     """A metric collector for accuracy. It only works for classification
     tasks.
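The `LearningRate` metric above is what `LRSchedulerHook` (added earlier in this commit) stores under `states['metrics']['train']['lr']`, which is presumably also how the learning rate ends up in the TensorBoard logs. The way that hook derives `total_steps` can be sketched as plain arithmetic (the numbers are only illustrative):

```python
# Per-iteration scheduling (by_epoch=False): the scheduler is stepped every batch.
max_epochs, steps_per_epoch, max_steps = 60, 500, None
total_steps = max_epochs * steps_per_epoch      # 30000 scheduler steps
if max_steps is not None:
    total_steps = min(total_steps, max_steps)   # optional hard cap on iterations

# Per-epoch scheduling (by_epoch=True): the scheduler only sees max_epochs steps.
total_steps_by_epoch = max_epochs               # 60 scheduler steps
```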
@@ -5,9 +5,9 @@ from typing import Tuple

 import torch

-from .context import Config
-from .context.parallel_mode import ParallelMode
-from .core import global_context as gpc
+from colossalai.context import Config
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.core import global_context as gpc

 __all__ = [
     'get_checkpoint_path',
@@ -27,7 +27,7 @@ def sync_model_param_in_dp(model):
     :param model: A pyTorch nn.model on whose parameters you check the consistency
     '''

-    if gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 2:
+    if gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1:
         for param in model.parameters():
             ranks = gpc.get_ranks_in_group(ParallelMode.DATA)
             dist.broadcast(param, src=ranks[0], group=gpc.get_group(ParallelMode.DATA))
@@ -4,6 +4,7 @@ import os

 IMG_SIZE = 224
 BATCH_SIZE = 256
+NUM_EPOCHS = 100

 model = dict(
     type='VanillaResNet',
@@ -67,8 +68,6 @@ loss = dict(
     type='CrossEntropyLoss'
 )

-max_epochs = 100
-
 from colossalai.engine import AMP_TYPE

 fp16 = dict(
@@ -1,21 +1,20 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

+NUM_EPOCH = int
+
 model = dict()
 train_data = dict()
 test_data = dict()
 optimizer = dict()
 loss = dict()
-lr_scheduler = dict()

 fp16 = dict()
 zero = dict()

 gradient_handler = []
 parallel = dict()
+hooks = []

-num_epochs = int
-num_steps = int

 cudnn_benchmark = True
 cudnn_deterministic = False
@@ -8,10 +8,11 @@ BATCH_SIZE = 512
 IMG_SIZE = 32
 PATCH_SIZE = 4
 DIM = 512
-NUM_ATTENTION_HEADS = 8
+NUM_ATTENTION_HEADS = 2
 SUMMA_DIM = 2
 NUM_CLASSES = 10
-DEPTH = 6
+DEPTH = 1
+NUM_EPOCHS = 60

 train_data = dict(
     dataset=dict(
@@ -127,14 +128,22 @@ hooks = [
     dict(type='LogMetricByEpochHook'),
     dict(type='Accuracy2DHook'),
     dict(type='LossHook'),
-    dict(type='TensorboardHook', log_dir='./tfb_logs'),
+    dict(
+        type='LRSchedulerHook',
+        by_epoch=True,
+        lr_scheduler_cfg=dict(
+            type='LinearWarmupLR',
+            warmup_steps=5
+        )
+    ),
+    dict(type='TensorboardHook', log_dir='./tb_logs'),
     # dict(type='SaveCheckpointHook', interval=5, checkpoint_dir='./ckpt'),
     # dict(type='LoadCheckpointHook', epoch=20, checkpoint_dir='./ckpt')
 ]

 parallel = dict(
     pipeline=dict(size=1),
-    tensor=dict(size=4, mode='2d'),
+    tensor=dict(size=1, mode='2d'),
 )

 # for fp16 training
@@ -144,17 +153,11 @@ parallel = dict(
 #     initial_scale=2 ** 8
 # )

-lr_scheduler = dict(
-    type='LinearWarmupLR',
-    warmup_epochs=5
-)
-
 # only needed when pipeline parallel is used
 # schedule = dict(
 #     num_microbatches=8
 # )

-num_epochs = 60
-
 logging = dict(
     root_path='./logs'
@@ -14,6 +14,7 @@ except:

 BATCH_SIZE = 512
 IMG_SIZE = 32
+NUM_EPOCHS = 60

 train_data = dict(
     dataset=dict(
@@ -83,6 +84,14 @@ hooks = [
     ),
     dict(type='LossHook'),
     dict(type='TensorboardHook', log_dir='./tfb_logs'),
+    dict(
+        type='LRSchedulerHook',
+        by_epoch=True,
+        lr_scheduler_cfg=dict(
+            type='LinearWarmupLR',
+            warmup_steps=5
+        )
+    ),
     # dict(type='SaveCheckpointHook', interval=5, checkpoint_dir='./ckpt'),
     # dict(type='LoadCheckpointHook', epoch=20, checkpoint_dir='./ckpt')
 ]
@@ -97,13 +106,6 @@ fp16 = dict(
     initial_scale=2 ** 8
 )

-lr_scheduler = dict(
-    type='LinearWarmupLR',
-    warmup_epochs=5
-)
-
-num_epochs = 60
-
 logging = dict(
     root_path='./logs'
 )
@@ -0,0 +1,5 @@
+colossalai.engine.amp.amp\_type
+===============================
+
+.. automodule:: colossalai.engine.amp.amp_type
+   :members:
@@ -0,0 +1,5 @@
+colossalai.engine.amp.grad\_scaler
+==================================
+
+.. automodule:: colossalai.engine.amp.grad_scaler
+   :members:
@@ -0,0 +1,12 @@
+colossalai.engine.amp
+=====================
+
+.. automodule:: colossalai.engine.amp
+   :members:
+
+
+.. toctree::
+   :maxdepth: 2
+
+   colossalai.engine.amp.amp_type
+   colossalai.engine.amp.grad_scaler
@@ -1,5 +0,0 @@
-colossalai.engine.amp\_type
-===========================
-
-.. automodule:: colossalai.engine.amp_type
-   :members:
@@ -7,11 +7,6 @@ colossalai.engine
 .. toctree::
    :maxdepth: 2

+   colossalai.engine.amp
    colossalai.engine.gradient_handler
    colossalai.engine.schedule
-
-
-.. toctree::
-   :maxdepth: 2
-
-   colossalai.engine.amp_type
@@ -21,7 +21,6 @@ colossalai
 .. toctree::
    :maxdepth: 2

-   colossalai.checkpointing
    colossalai.constants
    colossalai.core
    colossalai.initialize
@@ -0,0 +1,5 @@
+colossalai.utils.checkpointing
+==============================
+
+.. automodule:: colossalai.utils.checkpointing
+   :members:
@@ -9,6 +9,7 @@ colossalai.utils
    :maxdepth: 2

    colossalai.utils.activation_checkpoint
+   colossalai.utils.checkpointing
    colossalai.utils.common
    colossalai.utils.cuda
    colossalai.utils.memory
@@ -17,38 +17,40 @@ parallel = dict(
 )
 ```

 The name of the dictionary variable should be **parallel**. All the arguments even **parallel** itself are optional and
 data, pipeline, tensor parallel size will be set to defaulted value 1. The value of data, pipeline and tensor can be a
 int representing the size of specific parallel dimension or a dictionary with a key called "size". The key "mode"
 represents the way of tensor parallelism.

 ## Data Parallel

 Data parallel is the most common way to distribute your training task by splitting data into several shards and train on
 a single shard on each device. The configuration for data parallel is detected automatically and set for you. You do not
 have to explicitly set them in your configurations. When data parallel size is larger than 1, Colossal-AI automatically
 adds the distributed data sampler to the dataloader to shard the dataset.

 ## 1D, 2D, 2.5D and 3D Parallel

 To enable hybrid parallelism, we provide an array of tensor parallelism. We provide the list of papers which match each
 tensor parallel method. These parallel modes need to work with the distributed layers provided by Colossal-AI.

 - 1D: [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053)

 - 2D: [An Efficient 2D Method for Training Super-Large Deep Learning Models](https://arxiv.org/abs/2104.05343)
   2D parallel relies on the SUMMA matrix multiplication algorithm and splits the input data, model weights and layer
   outputs along two different dimensions. The tensor chunks are distributed over a 2D mesh of $P = N^2$ devices where
   $N$ is the number of tensor chunks in a single dimension.

 - 2.5D: [2.5-dimensional distributed model training](https://arxiv.org/abs/2105.14500)
   Inspired by the 2.5D matrix multiplication algorithm, 2.5D parallel introduces a novel tensor parallelism which
   further parallelizes 2D tensor parallelism. An amount of $P = N^2 ∗ d$ processors are arranged into $d$ layers, where
   each layer performs matrix multiplication operations independently with a dimension $N$.

 - 3D: [Maximizing Parallelism in Distributed Training for Huge Neural Networks](https://arxiv.org/abs/2105.14450)
   We also introduce a 3D tensor parallelism that parallelizes neural networks on a 3D processor cube. This method
   achieves the optimal, $O(P^{1/3})$ communication overhead on $P$ processors, while both computation and memory usage
   are evenly distributed through optimized load balancing of parameters as well as activations.

 ```python
 # 1D parallel
@@ -193,8 +195,8 @@ class VanillaResNet(BaseModel):
 ```

 You can set the number of pipeline stages in your configuration file. When pipeline size is larger than 1, Colossal-AI
 will automatically creates the pipeline schedule which defines the forward and backward step. You can specify how many
 microbatches to run in each step in the `schedule` configuration.

 ```python
 parallel = dict(
@@ -206,6 +208,7 @@ schedule = dict(
     num_microbatches = 4 # set the number of microbatches per step
 )
 ```

 This feature is still in development and is only experimental for now.

 ## Sequence Parallel (experimental)
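As the parallelization doc above explains, each entry in `parallel` accepts either a bare integer or a dictionary with a `size` key (plus `mode` for tensor parallelism). A hedged illustration combining both spellings (the sizes are placeholders, not a recommendation):

```python
parallel = dict(
    data=2,                          # plain int: data-parallel size 2
    pipeline=dict(size=1),           # dict form with an explicit "size" key
    tensor=dict(size=4, mode='2d'),  # "mode" picks the tensor-parallel method
)
```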
|
|
@ -1,8 +1,8 @@
# Quick demo

Colossal-AI is an integrated large-scale deep learning system with efficient parallelization techniques. The system can accelerate model training on distributed systems with multiple GPUs by applying parallelization techniques. The system can also run on systems with only one GPU. Quick demos showing how to use Colossal-AI are given below.

## Single GPU

@ -32,25 +32,17 @@ realizes the training process.
```python
import colossalai
from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger
from colossalai.trainer import Trainer


def run_trainer():
    engine, train_dataloader, test_dataloader = colossalai.initialize()
    logger = get_global_dist_logger()
    logger.info("engine is built", ranks=[0])

    trainer = Trainer(engine=engine,
                      verbose=True)
    logger.info("trainer is built", ranks=[0])

@ -58,11 +50,13 @@ def run_trainer():
    trainer.fit(
        train_dataloader=train_dataloader,
        test_dataloader=test_dataloader,
        epochs=gpc.config.num_epochs,
        hooks_cfg=gpc.config.hooks,
        display_progress=True,
        test_interval=2
    )


if __name__ == '__main__':
    run_trainer()
```

@ -72,9 +66,9 @@ Zoo. The detailed substitution process is elaborated [here](model.md).

## Features

Colossal-AI provides a collection of parallel training components for you. We aim to support you in developing distributed deep learning models just as you would write single-GPU models. We provide friendly tools to kickstart distributed training in a few lines.

- [Data Parallelism](parallelization.md)
- [Pipeline Parallelism](parallelization.md)

@ -4,40 +4,36 @@ Colossal-AI is a large-scale deep learning system that contains efficient parallelization techniques.

## Single-GPU system

When training a model on a non-distributed system with a GPU, Colossal-AI reaches the current baseline efficiency. [Here](https://colab.research.google.com/drive/1fJnqqFzPuzZ_kn1lwCpG2nh3l2ths0KE?usp=sharing#scrollTo=cQ_y7lBG09LS) we give a Google Colab example that shows how to use Colossal-AI to train a LeNet model on the CIFAR10 dataset on a non-distributed system.

## Multi-GPU system

When training deep learning models on a distributed system with multiple GPUs, Colossal-AI can use efficient parallelization techniques to significantly accelerate training; these techniques are described in detail in the [parallelization](parallelization.md) section below. The code below trains a ViT model on a distributed system with four GPUs, where `HOST` is the IP address of your distributed system. Note that the code uses the [Slurm](https://slurm.schedmd.com/documentation.html) job scheduler.

```bash
HOST=xxx.xxx.xxx.xxx srun ./scripts/slurm_dist_train.sh ./examples/run_trainer.py ./configs/vit/vit_2d.py
```

`./configs/vit/vit_2d.py` is a [configuration file](config.md). Colossal-AI uses configuration files to define the parameters needed during training, such as the model type, the dataset, the optimizer and the learning rate scheduler. You can train different models by writing different configuration files. `./examples/run_trainer.py` is a standard training script; its full code is attached below. The script reads the training parameters from the configuration file and trains the model.

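As a rough sketch of what such a configuration file contains (the field names follow the configs elsewhere in this commit; the values are illustrative and are not the actual contents of `./configs/vit/vit_2d.py`):

```python
# Illustrative configuration sketch -- not the real ./configs/vit/vit_2d.py.
BATCH_SIZE = 512
num_epochs = 60

optimizer = dict(type='Adam', lr=0.001)
loss = dict(type='CrossEntropyLoss2D')

# 2D tensor parallelism on 4 GPUs
parallel = dict(
    pipeline=dict(size=1),
    tensor=dict(size=4, mode='2d'),
)

# hooks drive logging, metrics and LR scheduling during training
hooks = [
    dict(
        type='LRSchedulerHook',
        by_epoch=True,
        lr_scheduler_cfg=dict(
            type='LinearWarmupLR',
            warmup_steps=5
        )
    ),
]
```
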
```python
import colossalai
from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger
from colossalai.trainer import Trainer


def run_trainer():
    engine, train_dataloader, test_dataloader = colossalai.initialize()
    logger = get_global_dist_logger()
    logger.info("engine is built", ranks=[0])

    trainer = Trainer(engine=engine,
                      verbose=True)
    logger.info("trainer is built", ranks=[0])

@ -45,11 +41,13 @@ def run_trainer():
    trainer.fit(
        train_dataloader=train_dataloader,
        test_dataloader=test_dataloader,
        epochs=gpc.config.num_epochs,
        hooks_cfg=gpc.config.hooks,
        display_progress=True,
        test_interval=2
    )


if __name__ == '__main__':
    run_trainer()
```

@ -2,9 +2,9 @@

## Build your engine

To better understand how the `Engine` class works, let's start from the concept of the process function in common engines. The process function usually controls the behavior over a batch of a dataset, and the `Engine` class just controls the process function. Here we give a standard process function in the following code block.

```python
def process_function(dataloader, model, criterion, optim):

@ -16,32 +16,33 @@ def process_function(dataloader, model, criterion, optim):
    optim.step()
```

In `ignite.engine` or `keras.engine`, the process function is always provided by users. However, it is tricky for users to write their own process functions for pipeline parallelism. Aiming at offering accessible hybrid parallelism for users, we provide the powerful `Engine` class. This class enables pipeline parallelism and offers a one-forward-one-backward non-interleaving strategy. Also, you can use a pre-defined learning rate scheduler in the `Engine` class to adjust the learning rate during training.

In order to build your engine, just set the variables `model`, `criterion`, `optimizer`, `lr_scheduler` and `schedule`. The following code block provides an example. **The engine is automatically created from the config file for you if you start with `colossalai.initialize`.**

```python
import torch
import torch.nn as nn
import torchvision.models as models
import colossalai
from colossalai.engine import Engine

model = models.resnet18()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
schedule = colossalai.engine.NoPipelineSchedule()

MyEngine = Engine(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    step_schedule=schedule
)
```

@ -51,21 +52,24 @@ More information regarding the class can be found in the API references.

### Overview

To learn how to customize a trainer which meets your needs, let's first take a look at the `Trainer` class. We highly recommend that you read *Get Started* and *Build your engine* first.

The `Trainer` class enables researchers and engineers to use our system more conveniently. Instead of having to write your own scripts, you can simply construct your own trainer by calling the `Trainer` class, just like what we did in the following code block.

```python
MyTrainer = Trainer(MyEngine)
```

After that, you can use the `fit` method to train or evaluate your model. In order to make our `Trainer` class even more powerful, we incorporate a set of handy tools into the class. For example, you can monitor or record the running states and metrics which indicate the current performance of the model. These functions are realized by hooks. The `BasicHook` class allows you to execute your hook functions at specified times. We have already created some practical hooks for you, as listed below. What you need to do is just pick the ones which suit your needs. Detailed descriptions of the class can be found in the API references.

```python
hooks = [

@ -80,18 +84,21 @@ hooks = [
]
```

These hook functions will record metrics, elapsed time and memory usage, and write them to the log after each epoch. Besides, they print the current loss and accuracy to let users monitor the performance of the model.

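Putting the pieces together, a minimal run with the `Trainer` API looks like the sketch below; it simply mirrors the quick-demo script, with `gpc.config.num_epochs` and `gpc.config.hooks` read from the configuration file and the test interval chosen for illustration.

```python
# Minimal sketch of a full training run, assuming colossalai.initialize()
# returned the engine and dataloaders as in the quick demo.
engine, train_dataloader, test_dataloader = colossalai.initialize()

trainer = Trainer(engine=engine, verbose=True)
trainer.fit(
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
    epochs=gpc.config.num_epochs,   # total number of epochs from the config
    hooks_cfg=gpc.config.hooks,     # hook definitions from the config
    display_progress=True,
    test_interval=2                 # evaluate every 2 epochs (illustrative)
)
```
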
### Hook

If you have specific needs, feel free to extend our `BaseHook` class to add your own functions, or our `MetricHook` class to write a metric collector. These hook functions can be called at twelve timings in the trainer's life cycle. Besides, you can define the priorities of all hooks to arrange their execution order. More information can be found in the API references.

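As a purely illustrative sketch (the import path, constructor signature and callback name below are assumptions, so check the API references for the real interface), a custom hook subclasses `BaseHook`, picks a priority and overrides the life-cycle callbacks it needs:

```python
from colossalai.trainer.hooks import BaseHook  # import path assumed for illustration


class PrintLossHook(BaseHook):
    """Toy hook that prints the loss after every training iteration.

    The constructor signature and the `after_train_iter` callback name are
    assumptions; consult the BaseHook documentation for the actual interface.
    """

    def __init__(self, trainer, priority=10):
        super().__init__(trainer, priority)

    def after_train_iter(self, output, label, loss):
        # one of the twelve life-cycle timings mentioned above
        print(f'train step loss: {loss.item()}')
```
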
### Metric

You can write your own metrics by extending our `Metric` class. It should be used with the `MetricHook` class. When you write your own metric hooks, please set the priority carefully and make sure the hook is called before other hooks which might require the results of the metric hook.

We've already provided some metric hooks, and we store metric objects in `runner.states['metrics']`. It is a dictionary, and metrics can be accessed by their names.

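For instance, a downstream hook can read a metric recorded by an earlier metric hook roughly as follows; the `'Loss'` key is an assumption for illustration, since the actual names depend on the metric hooks you configure.

```python
# Sketch: reading a metric object stored by a metric hook.
metrics = runner.states['metrics']   # dictionary mapping metric names to Metric objects
if 'Loss' in metrics:                # assumed key, for illustration only
    print(metrics['Loss'])
```
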
@ -14,28 +14,30 @@ def process_function(dataloader, model, criterion, optim):
    optim.step()
```

In `ignite.engine` and `keras.engine`, the process function must be provided by the user. However, it is hard for users to write a process function for pipeline parallelism. To offer users convenient hybrid parallelism, we provide the powerful `Engine` class, which supports pipeline parallelism and offers a non-interleaved forward-backward strategy. You can also use a learning rate scheduler you defined beforehand in the `Engine` class to adjust the learning rate during training.

To construct the engine, you only need to define the variables `model`, `criterion`, `optimizer`, `lr_scheduler` and `schedule`; the following code block gives an example. **If you start with `colossalai.initialize`, the engine is built automatically from the config file.**

```python
import torch
import torch.nn as nn
import torchvision.models as models
import colossalai
from colossalai.engine import Engine

model = models.resnet18()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
lr_scheduler = colossalai.nn.lr_scheduler.CosineAnnealingLR(optimizer, 1000)
schedule = colossalai.engine.NoPipelineSchedule()

MyEngine = Engine(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    step_schedule=schedule
)
```

@ -48,10 +50,12 @@ MyEngine = Engine(

The `Trainer` class is designed to make it easier for researchers and engineers to use our system. You do not need to write your own scripts; just call the `Trainer` class to construct your trainer, as in the code block below.

```python
MyTrainer = Trainer(MyEngine)
```

After that, you can use the `fit` method to train or evaluate your model. In addition, to make our `Trainer` class more powerful, we have added a series of convenient tools. For example, you can continuously monitor and record the model's running state and performance during training; these functions are implemented through hooks. The `BasicHook` class lets you execute your hook functions at specified times. As shown in the code block below, we have pre-defined some practical hooks for you; all you need to do is pick the ones that fit your needs. More information about this class can be found in the API references.

```python
hooks = [
```

@ -70,7 +74,8 @@ hooks = [

### Hooks

If you have specific needs, you can extend our `BaseHook` class to add your own hook functions, or extend our `MetricHook` class to write the metrics you need. These hook functions can be executed at twelve points in the `Trainer` life cycle. More information about these classes can be found in the API references.

### Metrics

@ -3,26 +3,18 @@
 import colossalai
 from colossalai.core import global_context as gpc
-from colossalai.engine import Engine
 from colossalai.logging import get_global_dist_logger
 from colossalai.trainer import Trainer


 def run_trainer():
-    model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize()
+    engine, train_dataloader, test_dataloader = colossalai.initialize()
     logger = get_global_dist_logger()
-    schedule.data_sync = False
-    engine = Engine(
-        model=model,
-        criterion=criterion,
-        optimizer=optimizer,
-        lr_scheduler=lr_scheduler,
-        schedule=schedule
-    )
+    engine.schedule.data_sync = False
     logger.info("engine is built", ranks=[0])

     trainer = Trainer(engine=engine,
-                      hooks_cfg=gpc.config.hooks,
                       verbose=True)
     logger.info("trainer is built", ranks=[0])

@ -30,7 +22,8 @@ def run_trainer():
     trainer.fit(
         train_dataloader=train_dataloader,
         test_dataloader=test_dataloader,
-        max_epochs=gpc.config.num_epochs,
+        epochs=gpc.config.num_epochs,
+        hooks_cfg=gpc.config.hooks,
         display_progress=True,
         test_interval=2
     )

@ -3,5 +3,5 @@ torchvision>=0.9
 numpy
 tqdm
 psutil
-tensorboardX
+tensorboard
 packaging

setup.py
@ -121,7 +121,7 @@ if "--cuda_ext" in sys.argv:
 install_requires = fetch_requirements('requirements/requirements.txt')

 setup(
-    name='colossal-ai',
+    name='colossalai',
     version='0.0.1-beta',
     packages=find_packages(exclude=('csrc',
                                     'tests',

@ -27,8 +27,6 @@ train_data = dict(
     dataloader=dict(
         batch_size=BATCH_SIZE,
         pin_memory=True,
-        # num_workers=1,
-        # shuffle=True,
     )
 )

@ -63,14 +61,6 @@ loss = dict(
     type='CrossEntropyLoss2D',
 )

-# model = dict(
-#     type='VanillaResNet',
-#     block_type='ResNetBasicBlock',
-#     layers=[2, 2, 2, 2],
-#     num_cls=10
-# )
-
 model = dict(
     type='VisionTransformerFromConfig',
     tensor_splitting_cfg=dict(

@ -135,25 +125,26 @@ parallel = dict(
 fp16 = dict(
     mode=AMP_TYPE.PARALLEL,
-    initial_scale=2 ** 8
 )

-# fp16 = dict(
-#     mode=None,
-# )
-
-schedule = dict(
-    num_microbatches=2
-)
-lr_scheduler = dict(
-    type='LinearWarmupLR',
-    warmup_epochs=5
-)
+engine = dict(
+    schedule=dict(
+        num_microbatches=2
+    )
+)
+
+hooks = [
+    dict(
+        type='LRSchedulerHook',
+        by_epoch=True,
+        lr_scheduler_cfg=dict(
+            type='LinearWarmupLR',
+            warmup_steps=5
+        )
+    ),
+]
 num_epochs = 60

 logging = dict(
     root_path='test_vit_2d_log'
 )
-
-seed = 100

@ -124,14 +124,21 @@
     tensor=dict(size=4, depth=1, mode='2.5d'),
 )

-lr_scheduler = dict(
-    type='LinearWarmupLR',
-    warmup_epochs=5
-)
+hooks = [
+    dict(
+        type='LRSchedulerHook',
+        by_epoch=True,
+        lr_scheduler_cfg=dict(
+            type='LinearWarmupLR',
+            warmup_steps=5
+        )
+    ),
+]

-schedule = dict(
-    num_microbatches=2
-)
+engine = dict(
+    schedule=dict(
+        num_microbatches=2
+    )
+)

 num_epochs = 60
-num_microbatches = 1

@ -9,21 +9,22 @@ import torch.autograd
 import colossalai
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.engine import Engine
 from colossalai.logging import get_global_dist_logger
 from colossalai.nn.layer._parallel_utilities import _gather

 CONFIG_PATH = Path(__file__).parent.parent.joinpath('configs/vit_2d.py')


-def eval(engine):
+def eval(engine, test_dataloader):
     engine.eval()
     accumulated_loss = 0
     correct_sum = 0
     total_sum = 0
+    num_steps = len(test_dataloader)
+    data_iter = iter(test_dataloader)

-    for i in range(engine.schedule.num_steps):
-        output, label, loss = engine.step()
+    for i in range(num_steps):
+        output, label, loss = engine.step(data_iter)

         if gpc.is_last_rank(ParallelMode.PIPELINE):
             # loss = sum(loss)

@ -43,20 +44,22 @@ def eval(engine):
             correct = torch.sum(label == output)
             correct_sum += correct
             total_sum += label.size(0)
-    avg_loss = accumulated_loss / engine.schedule.num_steps
+    avg_loss = accumulated_loss / num_steps
     return correct_sum, total_sum, avg_loss


-def train(engine):
+def train(engine, train_dataloader):
     engine.train()
     accumulated_loss = 0
+    num_steps = len(train_dataloader)
+    data_iter = iter(train_dataloader)

-    for i in range(engine.schedule.num_steps):
-        output, label, loss = engine.step()
+    for i in range(num_steps):
+        output, label, loss = engine.step(data_iter)

         if gpc.is_last_rank(ParallelMode.PIPELINE):
             accumulated_loss += loss.detach().cpu().numpy()
-    avg_loss = accumulated_loss / engine.schedule.num_steps
+    avg_loss = accumulated_loss / num_steps
     return avg_loss


@ -64,25 +67,16 @@ def train(engine):
 @pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus")
 def test_2d_parallel_vision_transformer():
     # init dist
-    model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize(
-        CONFIG_PATH)
+    engine, train_dataloader, test_dataloader = colossalai.initialize(CONFIG_PATH)
     logger = get_global_dist_logger()

-    engine = Engine(model=model,
-                    train_dataloader=train_dataloader,
-                    test_dataloader=test_dataloader,
-                    criterion=criterion,
-                    optimizer=optimizer,
-                    lr_scheduler=lr_scheduler,
-                    schedule=schedule)
-
     for epoch in range(gpc.config.num_epochs):
-        train_loss = train(engine)
+        train_loss = train(engine, train_dataloader)
         if gpc.is_last_rank(ParallelMode.PIPELINE):
             logger.info(f'epoch {epoch} - train loss: {train_loss}')

         if epoch % 2 == 0:
-            correct_sum, total_sum, eval_loss = eval(engine)
+            correct_sum, total_sum, eval_loss = eval(engine, test_dataloader)
             if gpc.is_last_rank(ParallelMode.PIPELINE):
                 logger.info(
                     f'epoch {epoch} - eval loss: {eval_loss}, total: {total_sum}, '

@ -6,20 +6,22 @@ import torch.autograd
 import colossalai
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.engine import Engine
 from colossalai.logging import get_global_dist_logger
 from colossalai.nn.layer._parallel_utilities import _gather

 CONFIG_PATH = Path(__file__).parent.parent.joinpath('configs/vit_2p5d.py')

-def eval(engine):
+
+def eval(engine, test_dataloader):
     engine.eval()
     accumulated_loss = 0
     correct_sum = 0
     total_sum = 0
+    num_steps = len(test_dataloader)
+    data_iter = iter(test_dataloader)

-    for i in range(engine.schedule.num_steps):
-        output, label, loss = engine.step()
+    for i in range(num_steps):
+        output, label, loss = engine.step(data_iter)

         if gpc.is_last_rank(ParallelMode.PIPELINE):
             accumulated_loss += loss.detach().cpu().numpy()

@ -43,21 +45,23 @@ def eval(engine):
             correct = torch.sum(label == output)
             correct_sum += correct
             total_sum += label.size(0)
-    avg_loss = accumulated_loss / engine.schedule.num_steps
+    avg_loss = accumulated_loss / num_steps
     return correct_sum, total_sum, avg_loss


-def train(engine):
+def train(engine, train_dataloader):
     engine.train()
     accumulated_loss = 0
+    num_steps = len(train_dataloader)
+    data_iter = iter(train_dataloader)

-    for i in range(engine.schedule.num_steps):
-        output, label, loss = engine.step()
+    for i in range(num_steps):
+        output, label, loss = engine.step(data_iter)

         if gpc.is_last_rank(ParallelMode.PIPELINE):
             accumulated_loss += loss.detach().cpu().numpy()

-    avg_loss = accumulated_loss / engine.schedule.num_steps
+    avg_loss = accumulated_loss / num_steps
     return avg_loss


@ -65,25 +69,16 @@ def train(engine):
 @pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus")
 def test_2p5d_parallel_vision_transformer():
     # init dist
-    model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize(
-        CONFIG_PATH)
+    engine, train_dataloader, test_dataloader = colossalai.initialize(CONFIG_PATH)
     logger = get_global_dist_logger()

-    engine = Engine(model=model,
-                    train_dataloader=train_dataloader,
-                    test_dataloader=test_dataloader,
-                    criterion=criterion,
-                    optimizer=optimizer,
-                    lr_scheduler=lr_scheduler,
-                    schedule=schedule)
-
     for epoch in range(gpc.config.num_epochs):
-        train_loss = train(engine)
+        train_loss = train(engine, train_dataloader)
         if gpc.is_last_rank(ParallelMode.PIPELINE):
             logger.info(f'epoch {epoch} - train loss: {train_loss}')

         if epoch % 2 == 0:
-            correct_sum, total_sum, eval_loss = eval(engine)
+            correct_sum, total_sum, eval_loss = eval(engine, test_dataloader)
             if gpc.is_last_rank(ParallelMode.PIPELINE):
                 logger.info(
                     f'epoch {epoch} - eval loss: {eval_loss}, total: {total_sum}, '

@ -38,5 +38,3 @@ optimizer = dict(type='Adam', lr=0.001)

 loss = dict(type='CrossEntropyLoss')

-# set_device_func = lambda global_rank, world_size: global_rank % 4
-seed = 1024

@ -40,6 +40,3 @@ optimizer = dict(type='Adam', lr=0.001)

 loss = dict(type='CrossEntropyLoss')
 fp16 = dict(mode=AMP_TYPE.APEX)
-
-# set_device_func = lambda global_rank, world_size: global_rank % 4
-seed = 1024

@ -40,6 +40,3 @@ optimizer = dict(type='Adam', lr=0.001)

 loss = dict(type='CrossEntropyLoss')
 fp16 = dict(mode=AMP_TYPE.TORCH)
-
-# set_device_func = lambda global_rank, world_size: global_rank % 4
-seed = 1024

@ -38,11 +38,9 @@ parallel = dict(
     tensor=dict(size=1, mode=None)
 )

-schedule = dict(
-    num_microbatches=4
-)
-num_pipeling_batches = 2
-seed = 1024
-lr_scheduler = dict(type='LinearWarmupLR', warmup_steps=5)
+engine = dict(
+    schedule=dict(
+        num_microbatches=4
+    )
+)

 num_epochs = 10

@ -8,7 +8,6 @@ import torch

 from colossalai import initialize
 from colossalai.core import global_context as gpc
-from colossalai.engine import Engine
 from colossalai.logging import get_global_dist_logger
 from colossalai.utils import report_memory_usage

@ -24,20 +23,13 @@ NO_PIPE_CONFIG_PATH = osp.join(DIR_PATH, '../configs/non_pipeline_resnet_apex_am

 def run_no_pipeline(config):
-    model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = initialize(config)
+    engine, train_dataloader, test_dataloader = initialize(config)
     logger = get_global_dist_logger()
     rank = torch.distributed.get_rank()

-    engine = Engine(model=model,
-                    train_dataloader=train_dataloader,
-                    criterion=criterion,
-                    optimizer=optimizer,
-                    schedule=schedule)
     engine.train()
-    logger.info('lr = %g' % engine.get_lr())
-    output, label, loss = engine.step()
+    output, label, loss = engine.step(iter(train_dataloader))
     logger.info('Rank {} returns: {}'.format(rank, loss.item()))
-    logger.info('lr = %g' % engine.get_lr())

     gpc.destroy()
     logger.info('Test engine finished')

@ -8,7 +8,6 @@ import torch

 from colossalai import initialize
 from colossalai.core import global_context as gpc
-from colossalai.engine import Engine
 from colossalai.logging import get_global_dist_logger
 from colossalai.utils import report_memory_usage

@ -26,21 +25,14 @@ NO_PIPE_CONFIG_PATH = osp.join(DIR_PATH, '../configs/non_pipeline_resnet.py')
 def test_no_pipeline(config):
     print('Test no pipeline engine start')

-    model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = initialize(config)
+    engine, train_dataloader, test_dataloader = initialize(config)
     logger = get_global_dist_logger()

     rank = torch.distributed.get_rank()
-    engine = Engine(model=model,
-                    train_dataloader=train_dataloader,
-                    criterion=criterion,
-                    optimizer=optimizer,
-                    schedule=schedule)

     engine.train()
-    logger.info('lr = %g' % engine.get_lr())
-    output, label, loss = engine.step()
+    output, label, loss = engine.step(iter(train_dataloader))
     logger.info('Rank {} returns: {}'.format(rank, loss.item()))
-    logger.info('lr = %g' % engine.get_lr())

     gpc.destroy()
     logger.info('Test engine finished')

@ -8,7 +8,6 @@ import torch

 from colossalai import initialize
 from colossalai.core import global_context as gpc
-from colossalai.engine import Engine
 from colossalai.logging import get_global_dist_logger
 from colossalai.utils import report_memory_usage

@ -26,21 +25,13 @@ NO_PIPE_CONFIG_PATH = osp.join(DIR_PATH, '../configs/non_pipeline_resnet_torch_a
 def test_no_pipeline(config):
     print('Test no pipeline engine start')

-    model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = initialize(config)
+    engine, train_dataloader, test_dataloader = initialize(config)
     logger = get_global_dist_logger()

     rank = torch.distributed.get_rank()
-    engine = Engine(model=model,
-                    train_dataloader=train_dataloader,
-                    criterion=criterion,
-                    optimizer=optimizer,
-                    schedule=schedule)

     engine.train()
-    logger.info('lr = %g' % engine.get_lr())
-    output, label, loss = engine.step()
+    output, label, loss = engine.step(iter(train_dataloader))
     logger.info('Rank {} returns: {}'.format(rank, loss.item()))
-    logger.info('lr = %g' % engine.get_lr())

     gpc.destroy()
     logger.info('Test engine finished')

@ -5,6 +5,7 @@ import os.path as osp

 import pytest

+from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.initialize import initialize
 from colossalai.logging import get_global_dist_logger

@ -22,13 +23,25 @@ CONFIG_PATH = osp.join(DIR_PATH, '../configs/pipeline_vanilla_resnet.py')
 @pytest.mark.skip("This test should be invoked using the test.sh provided")
 @pytest.mark.dist
 def test_schedule():
-    model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = initialize(CONFIG_PATH)
+    engine, train_dataloader, test_dataloader = initialize(CONFIG_PATH)
     logger = get_global_dist_logger()

-    schedule.zero_grad()
-    output, label, losses = schedule.forward_backward_step(forward_only=False)
-    schedule.step()
-    logger.info('losses: {}'.format([loss.item() for loss in losses]))
+    model = engine.model
+    optimizer = engine.optimizer
+    criterion = engine.criterion
+    schedule = engine._schedule
+
+    output, label, loss = schedule.forward_backward_step(
+        data_iter=iter(train_dataloader),
+        model=model,
+        optimizer=optimizer,
+        criterion=criterion,
+        forward_only=False
+    )
+    schedule.optimizer_step(model, optimizer)
+
+    if gpc.is_last_rank(ParallelMode.PIPELINE):
+        logger.info('losses: {}'.format(loss))

     gpc.destroy()
     logger.info('training finished')

@ -9,7 +9,6 @@ import torch
 from colossalai import initialize
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.engine import Engine
 from colossalai.logging import get_global_dist_logger

 NUM_BATCH = 128

@ -23,22 +22,14 @@ PIPE_CONFIG_PATH = osp.join(DIR_PATH, '../configs/pipeline_vanilla_resnet.py')


 def run_pipeline(config):
-    model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = initialize(config)
+    engine, train_dataloader, test_dataloader = initialize(config)
     logger = get_global_dist_logger()
     rank = torch.distributed.get_rank()
-    engine = Engine(model=model,
-                    train_dataloader=train_dataloader,
-                    criterion=criterion,
-                    optimizer=optimizer,
-                    lr_scheduler=lr_scheduler,
-                    schedule=schedule)

     engine.train()
-    logger.info('lr = %g' % engine.get_lr())
-    outputs, labels, loss = engine.step()
+    outputs, labels, loss = engine.step(iter(train_dataloader))
     if gpc.is_last_rank(ParallelMode.PIPELINE):
         logger.info('losses: {}'.format(rank, loss.item()))
-    logger.info('lr = %g' % engine.get_lr())

     gpc.destroy()
     logger.info('Test engine pipeline finished')

@ -132,9 +132,12 @@ fp16 = dict(
     initial_scale=2 ** 4
 )

+num_epochs = 60
+
 lr_scheduler = dict(
     type='LinearWarmupLR',
-    warmup_epochs=5
+    warmup_steps=5,
+    total_steps=num_epochs
 )
-
-num_epochs = 60

@ -7,23 +7,25 @@ import pytest
 import torch.autograd

 import colossalai
+from colossalai.builder import build_lr_scheduler
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.engine import Engine
 from colossalai.logging import get_global_dist_logger
 from colossalai.nn.layer._parallel_utilities import _gather

 CONFIG_PATH = Path(__file__).parent.parent.joinpath('configs/vit_2d.py')


-def eval(engine):
+def eval(engine, test_dataloader):
     engine.eval()
     accumulated_loss = 0
     correct_sum = 0
     total_sum = 0
+    num_steps = len(test_dataloader)
+    data_iter = iter(test_dataloader)

-    for i in range(engine.schedule.num_steps):
-        output, label, loss = engine.step()
+    for i in range(num_steps):
+        output, label, loss = engine.step(data_iter)
         accumulated_loss += loss.detach().cpu().numpy()

         output = _gather(

@ -40,18 +42,21 @@ def eval(engine):
     correct = torch.sum(label[0] == output)
     correct_sum += correct
     total_sum += label[0].size(0)
-    avg_loss = accumulated_loss / engine.schedule.num_steps
+    avg_loss = accumulated_loss / num_steps
     return correct_sum, total_sum, avg_loss


-def train(engine):
+def train(engine, train_dataloader, lr_scheduler):
     engine.train()
     accumulated_loss = 0
+    num_steps = len(train_dataloader)
+    data_iter = iter(train_dataloader)

-    for i in range(engine.schedule.num_steps):
-        output, label, loss = engine.step()
+    for i in range(num_steps):
+        output, label, loss = engine.step(data_iter)
         accumulated_loss += loss.squeeze(0).detach().cpu().numpy()
-    avg_loss = accumulated_loss / engine.schedule.num_steps
+    avg_loss = accumulated_loss / num_steps
+    lr_scheduler.step()
     return avg_loss


@ -59,26 +64,18 @@ def train(engine):
 @pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus")
 def test_2d_parallel_vision_transformer():
     # init dist
-    model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize(
-        CONFIG_PATH)
+    engine, train_dataloader, test_dataloader = colossalai.initialize(CONFIG_PATH)
+    lr_scheduler = build_lr_scheduler(gpc.config.lr_scheduler, engine.optimizer)
     logger = get_global_dist_logger()

-    engine = Engine(model=model,
-                    train_dataloader=train_dataloader,
-                    test_dataloader=test_dataloader,
-                    criterion=criterion,
-                    optimizer=optimizer,
-                    lr_scheduler=lr_scheduler,
-                    schedule=schedule)
-
     logger.info('start training')
     for epoch in range(gpc.config.num_epochs):
-        train_loss = train(engine)
+        train_loss = train(engine, train_dataloader, lr_scheduler)

         logger.info(f'epoch {epoch} - train loss: {train_loss}')

         if epoch % 2 == 0:
-            correct_sum, total_sum, eval_loss = eval(engine)
+            correct_sum, total_sum, eval_loss = eval(engine, test_dataloader)
             logger.info(
                 f'epoch {epoch} - eval loss: {eval_loss}, total: {total_sum}, '
                 f'correct: {correct_sum}, acc: {correct_sum / total_sum}')

|
||||||
tensor=dict(size=4, mode='2d'),
|
tensor=dict(size=4, mode='2d'),
|
||||||
)
|
)
|
||||||
|
|
||||||
lr_scheduler = dict(type='LinearWarmupLR', warmup_epochs=5)
|
|
||||||
|
|
||||||
num_epochs = 60
|
num_epochs = 60
|
||||||
|
|
||||||
|
lr_scheduler = dict(type='LinearWarmupLR', warmup_steps=5, total_steps=num_epochs)
|
||||||
|
|
|
@ -125,13 +125,6 @@
     tensor=dict(size=4, depth=1, mode='2.5d'),
 )

-lr_scheduler = dict(
-    type='LinearWarmupLR',
-    warmup_epochs=5
-)
-
-schedule = dict(
-    num_microbatches=8
-)
-
 num_epochs = 60

+lr_scheduler = dict(type='LinearWarmupLR', warmup_steps=5, total_steps=num_epochs)

@ -116,9 +116,14 @@ hooks = [
         weight_parallel_mode=ParallelMode.PARALLEL_3D_WEIGHT,
     ),
     dict(type='LossHook'),
-    # dict(type='TensorboardHook', log_dir='./tfb_logs'),
-    # dict(type='SaveCheckpointHook', interval=5, checkpoint_dir='./ckpt'),
-    # dict(type='LoadCheckpointHook', epoch=20, checkpoint_dir='./ckpt')
+    dict(
+        type='LRSchedulerHook',
+        by_epoch=True,
+        lr_scheduler_cfg=dict(
+            type='LinearWarmupLR',
+            warmup_steps=5
+        )
+    ),
 ]

 parallel = dict(

@ -127,12 +132,4 @@ parallel = dict(
     tensor=dict(mode='3d', size=8),
 )

-# fp16 = dict(mode=AMP_TYPE.PARALLEL, initial_scale=2 ** 6)
-
-lr_scheduler = dict(type='LinearWarmupLR', warmup_epochs=5)
-
-# schedule = dict(num_microbatches=4)
-
 num_epochs = 60
-
-seed = 42

@ -7,23 +7,25 @@ import pytest
|
||||||
import torch.autograd
|
import torch.autograd
|
||||||
|
|
||||||
import colossalai
|
import colossalai
|
||||||
|
from colossalai.builder import build_lr_scheduler
|
||||||
from colossalai.context.parallel_mode import ParallelMode
|
from colossalai.context.parallel_mode import ParallelMode
|
||||||
from colossalai.core import global_context as gpc
|
from colossalai.core import global_context as gpc
|
||||||
from colossalai.engine import Engine
|
|
||||||
from colossalai.logging import get_global_dist_logger
|
from colossalai.logging import get_global_dist_logger
|
||||||
from colossalai.nn.layer._parallel_utilities import _gather
|
from colossalai.nn.layer._parallel_utilities import _gather
|
||||||
|
|
||||||
 CONFIG_PATH = Path(__file__).parent.parent.joinpath('configs/vit_2d.py')


-def eval(engine):
+def eval(engine, test_dataloader):
     engine.eval()
     accumulated_loss = 0
     correct_sum = 0
     total_sum = 0
+    num_steps = len(test_dataloader)
+    data_iter = iter(test_dataloader)

-    for i in range(engine.schedule.num_steps):
-        output, label, loss = engine.step()
+    for i in range(num_steps):
+        output, label, loss = engine.step(data_iter)
         accumulated_loss += loss.detach().cpu().numpy()

         output = _gather(
@@ -40,18 +42,21 @@ def eval(engine):
         correct = torch.sum(label[0] == output)
         correct_sum += correct
         total_sum += label[0].size(0)
-    avg_loss = accumulated_loss / engine.schedule.num_steps
+    avg_loss = accumulated_loss / num_steps
     return correct_sum, total_sum, avg_loss


-def train(engine):
+def train(engine, train_dataloader, lr_scheduler):
     engine.train()
     accumulated_loss = 0
+    num_steps = len(train_dataloader)
+    data_iter = iter(train_dataloader)

-    for i in range(engine.schedule.num_steps):
-        output, label, loss = engine.step()
+    for i in range(num_steps):
+        output, label, loss = engine.step(data_iter)
         accumulated_loss += loss.detach().cpu().numpy()
-    avg_loss = accumulated_loss / engine.schedule.num_steps
+    avg_loss = accumulated_loss / num_steps
+    lr_scheduler.step()
     return avg_loss


@@ -59,25 +64,17 @@ def train(engine):
 @pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus")
 def test_2d_parallel_vision_transformer():
     # init dist
-    model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize(
-        CONFIG_PATH)
+    engine, train_dataloader, test_dataloader = colossalai.initialize(CONFIG_PATH)
+    lr_scheduler = build_lr_scheduler(gpc.config.lr_scheduler, engine.optimizer)
     logger = get_global_dist_logger()

-    engine = Engine(model=model,
-                    train_dataloader=train_dataloader,
-                    test_dataloader=test_dataloader,
-                    criterion=criterion,
-                    optimizer=optimizer,
-                    lr_scheduler=lr_scheduler,
-                    schedule=schedule)
-
     logger.info('start training')
     for epoch in range(gpc.config.num_epochs):
-        train_loss = train(engine)
+        train_loss = train(engine, train_dataloader, lr_scheduler)
         logger.info(f'epoch {epoch} - train loss: {train_loss}')

         if epoch % 2 == 0:
-            correct_sum, total_sum, eval_loss = eval(engine)
+            correct_sum, total_sum, eval_loss = eval(engine, test_dataloader)
             logger.info(
                 f'epoch {epoch} - eval loss: {eval_loss}, total: {total_sum}, '
                 f'correct: {correct_sum}, acc: {correct_sum / total_sum}')
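The pattern introduced above moves step counting and data iteration out of the engine: the caller computes `num_steps = len(dataloader)`, creates `data_iter = iter(dataloader)`, and hands the iterator to `engine.step(data_iter)`. Below is a minimal stand-alone sketch of that loop shape in plain PyTorch; `run_step` is a hypothetical stand-in for `engine.step` and the toy model and data are placeholders, not part of the repository.

```python
# Stand-alone sketch of the iterator-driven epoch loop used in the test above.
# `run_step` is a hypothetical stand-in for engine.step(data_iter); the real
# engine also handles schedules, AMP and tensor parallelism.
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(256, 8), torch.randint(0, 2, (256,)))
loader = DataLoader(dataset, batch_size=32)
model = torch.nn.Linear(8, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()


def run_step(data_iter):
    # one optimizer step driven by the shared iterator
    data, label = next(data_iter)
    output = model(data)
    loss = criterion(output, label)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return output, label, loss


def train_one_epoch():
    model.train()
    num_steps = len(loader)      # number of batches, as in the test above
    data_iter = iter(loader)     # the caller now owns the iterator
    accumulated_loss = 0.0
    for _ in range(num_steps):
        _, _, loss = run_step(data_iter)
        accumulated_loss += loss.detach().cpu().item()
    return accumulated_loss / num_steps


print(train_one_epoch())
```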
@@ -4,22 +4,25 @@ import pytest
 import torch.autograd

 import colossalai
+from colossalai.builder import build_lr_scheduler
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.engine import Engine
 from colossalai.logging import get_global_dist_logger
 from colossalai.nn.layer._parallel_utilities import _gather

 CONFIG_PATH = Path(__file__).parent.parent.joinpath('configs/vit_2p5d.py')

-def eval(engine):
+
+def eval(engine, test_dataloader):
     engine.eval()
     accumulated_loss = 0
     correct_sum = 0
     total_sum = 0
+    num_steps = len(test_dataloader)
+    data_iter = iter(test_dataloader)

-    for i in range(engine.schedule.num_steps):
-        output, label, loss = engine.step()
+    for i in range(num_steps):
+        output, label, loss = engine.step(data_iter)
         accumulated_loss += loss.detach().cpu().numpy()

         output = _gather(
@@ -41,18 +44,21 @@ def eval(engine):
         correct = torch.sum(label[0] == output)
         correct_sum += correct
         total_sum += label[0].size(0)
-    avg_loss = accumulated_loss / engine.schedule.num_steps
+    avg_loss = accumulated_loss / num_steps
     return correct_sum, total_sum, avg_loss


-def train(engine):
+def train(engine, train_dataloader, lr_scheduler):
     engine.train()
     accumulated_loss = 0
+    num_steps = len(train_dataloader)
+    data_iter = iter(train_dataloader)

-    for i in range(engine.schedule.num_steps):
-        output, label, loss = engine.step()
+    for i in range(num_steps):
+        output, label, loss = engine.step(data_iter)
         accumulated_loss += loss.detach().cpu().numpy()
-    avg_loss = accumulated_loss / engine.schedule.num_steps
+    avg_loss = accumulated_loss / num_steps
+    lr_scheduler.step()
     return avg_loss


@@ -60,25 +66,17 @@ def train(engine):
 @pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus")
 def test_2p5d_parallel_vision_transformer():
     # init dist
-    model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize(
-        CONFIG_PATH)
+    engine, train_dataloader, test_dataloader = colossalai.initialize(CONFIG_PATH)
+    lr_scheduler = build_lr_scheduler(gpc.config.lr_scheduler, engine.optimizer)
     logger = get_global_dist_logger()

-    engine = Engine(model=model,
-                    train_dataloader=train_dataloader,
-                    test_dataloader=test_dataloader,
-                    criterion=criterion,
-                    optimizer=optimizer,
-                    lr_scheduler=lr_scheduler,
-                    schedule=schedule)
-
     logger.info('start training')
     for epoch in range(gpc.config.num_epochs):
-        train_loss = train(engine)
+        train_loss = train(engine, train_dataloader, lr_scheduler)
         logger.info(f'epoch {epoch} - train loss: {train_loss}')

         if epoch % 2 == 0:
-            correct_sum, total_sum, eval_loss = eval(engine)
+            correct_sum, total_sum, eval_loss = eval(engine, test_dataloader)
             logger.info(
                 f'epoch {epoch} - eval loss: {eval_loss}, total: {total_sum}, '
                 f'correct: {correct_sum}, acc: {correct_sum / total_sum}')
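Because `colossalai.initialize` now returns only the engine and dataloaders, this test builds its scheduler separately with `build_lr_scheduler(gpc.config.lr_scheduler, engine.optimizer)` and steps it once per epoch. The sketch below imitates that flow with plain PyTorch classes and a hypothetical lookup table; ColossalAI resolves the config's `type` field through its own registry, so this is only an illustration of the idea.

```python
# Simplified illustration of building a scheduler from a config dict and
# stepping it once per epoch, mirroring build_lr_scheduler(...) in the test.
# The SCHEDULERS table and helper below are assumptions for this sketch.
import torch

SCHEDULERS = {
    'CosineAnnealingLR': torch.optim.lr_scheduler.CosineAnnealingLR,
}


def build_lr_scheduler_from_cfg(cfg, optimizer):
    cfg = dict(cfg)                          # don't mutate the original config
    sched_cls = SCHEDULERS[cfg.pop('type')]  # resolve the class by its name
    return sched_cls(optimizer, **cfg)


optimizer = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)
lr_scheduler = build_lr_scheduler_from_cfg(
    dict(type='CosineAnnealingLR', T_max=60), optimizer)

for epoch in range(3):
    optimizer.step()        # stand-in for one epoch of training steps
    lr_scheduler.step()     # advanced once per epoch, as in train(...) above
    print(epoch, lr_scheduler.get_last_lr())
```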
@@ -1,16 +1,14 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

 import time
 from pathlib import Path

 import torch
 from tqdm import tqdm

-from colossalai import initialize
+import colossalai
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.engine import Engine
 from colossalai.logging import get_global_dist_logger
 from colossalai.trainer import Trainer
 from colossalai.trainer.metric import Accuracy3D

@@ -29,7 +27,7 @@ def _train_epoch(epoch, engine):
     num_samples = 0
     now = time.time()
     epoch_start = now
-    progress = range(engine.schedule.num_steps)
+    progress = range(engine._schedule.num_steps)
     if gpc.get_global_rank() == 0:
         progress = tqdm(progress, desc='[Epoch %d]' % epoch, miniters=1)
     for step in progress:

@@ -68,7 +66,7 @@ def _eval(epoch, engine):
         ParallelMode.PARALLEL_3D_WEIGHT)
     total = 0
     with torch.no_grad():
-        for _ in range(engine.schedule.num_steps):
+        for _ in range(engine._schedule.num_steps):
             outputs, targets, loss = engine.step()
             if isinstance(outputs, (list, tuple)):
                 outputs = outputs[0]

@@ -80,32 +78,25 @@ def _eval(epoch, engine):

     print_rank_0(
         '[Epoch %d] Evaluation loss: %.3f | Acc: %.3f%%' %
-        (epoch, eval_loss / engine.schedule.num_steps,
+        (epoch, eval_loss / engine._schedule.num_steps,
         acc.get_accumulated_value() * 100), logger)


 def train():
-    model, train_dataloader, test_dataloader, criterion, \
-        optimizer, schedule, lr_scheduler = initialize(CONFIG_PATH)
+    # init dist
+    engine, train_dataloader, test_dataloader = colossalai.initialize(CONFIG_PATH)

     logger = get_global_dist_logger()

-    engine = Engine(model=model,
-                    train_dataloader=train_dataloader,
-                    test_dataloader=test_dataloader,
-                    criterion=criterion,
-                    optimizer=optimizer,
-                    lr_scheduler=lr_scheduler,
-                    schedule=schedule)
     logger.info("Engine is built", ranks=[0])

-    trainer = Trainer(engine=engine, hooks_cfg=gpc.config.hooks, verbose=True)
+    trainer = Trainer(engine=engine, verbose=True)
     logger.info("Trainer is built", ranks=[0])

     logger.info("Train start", ranks=[0])
     trainer.fit(train_dataloader=train_dataloader,
                 test_dataloader=test_dataloader,
-                max_epochs=gpc.config.num_epochs,
+                epochs=gpc.config.num_epochs,
+                hooks_cfg=gpc.config.hooks,
                 display_progress=True,
                 test_interval=1)
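The 3D script keeps a hand-written epoch loop for timing, now reading the step count from `engine._schedule.num_steps` and showing a tqdm bar only on global rank 0. Below is a stand-alone sketch of that loop skeleton; `global_rank` and the empty step body are hypothetical stand-ins for `gpc.get_global_rank()` and `engine.step()`.

```python
# Sketch of the _train_epoch skeleton above: rank-0-only progress bar and a
# rough per-epoch throughput estimate. The step body is intentionally empty.
import time
from tqdm import tqdm


def _train_epoch(epoch, num_steps, global_rank=0, batch_size=32):
    num_samples = 0
    epoch_start = time.time()
    progress = range(num_steps)
    if global_rank == 0:
        # only the global rank 0 process renders the progress bar
        progress = tqdm(progress, desc='[Epoch %d]' % epoch, miniters=1)
    for step in progress:
        # engine.step() would run one forward/backward pass here
        num_samples += batch_size
    elapsed = time.time() - epoch_start
    return num_samples / max(elapsed, 1e-8)   # samples per second


print(_train_epoch(0, num_steps=10))
```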
@@ -3,6 +3,7 @@ from pathlib import Path

 BATCH_SIZE = 128
 IMG_SIZE = 32
+num_epochs = 200

 # resnet 50
 model = dict(

@@ -77,18 +78,14 @@ hooks = [
     dict(type='AccuracyHook'),
     dict(type='LossHook'),
     dict(type='TensorboardHook', log_dir='./tfb_logs'),
+    dict(
+        type='LRSchedulerHook',
+        by_epoch=True,
+        lr_scheduler_cfg=dict(
+            type='CosineAnnealingLR',
+            warmup_steps=5
+        )
+    ),
     dict(type='SaveCheckpointHook', interval=5, checkpoint_dir='./ckpt'),
-    # dict(type='LoadCheckpointHook', epoch=20, checkpoint_dir='./ckpt')
 ]

-# fp16 = dict(
-#     mode=AMP_TYPE.PARALLEL,
-#     initial_scale=1
-# )
-
-lr_scheduler = dict(
-    type='CosineAnnealingLR',
-    T_max=200
-)
-
-num_epochs = 200
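In the updated config, the top-level `lr_scheduler` dict is removed and scheduling is declared as an `LRSchedulerHook` inside `hooks`, with `by_epoch=True` controlling when the scheduler advances. The snippet below only illustrates that flag's semantics with a plain PyTorch scheduler (warmup handling omitted); it is not the hook's actual implementation.

```python
# Illustration of the by_epoch flag: step the scheduler once per epoch when
# True, once per training iteration when False. Plain PyTorch, not ColossalAI.
import torch

hook_cfg = dict(
    type='LRSchedulerHook',
    by_epoch=True,
    lr_scheduler_cfg=dict(type='CosineAnnealingLR', warmup_steps=5),
)

optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

steps_per_epoch = 100
for epoch in range(2):
    for it in range(steps_per_epoch):
        optimizer.step()                  # stand-in for one training step
        if not hook_cfg['by_epoch']:
            scheduler.step()              # per-iteration stepping
    if hook_cfg['by_epoch']:
        scheduler.step()                  # per-epoch stepping, as configured
    print(epoch, scheduler.get_last_lr())
```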
@@ -11,6 +11,7 @@ NUM_ATTENTION_HEADS = 8
 SUMMA_DIM = 2
 NUM_CLASSES = 10
 DEPTH = 6
+num_epochs = 60

 train_data = dict(
     dataset=dict(type='CIFAR10Dataset',

@@ -52,13 +53,6 @@ optimizer = dict(type='Adam', lr=0.001, weight_decay=0)

 loss = dict(type='CrossEntropyLoss2D', )

-# model = dict(
-#     type='VanillaResNet',
-#     block_type='ResNetBasicBlock',
-#     layers=[2, 2, 2, 2],
-#     num_cls=10
-# )
-
 model = dict(
     type='VisionTransformerFromConfig',
     tensor_splitting_cfg=dict(type='ViTInputSplitter2D', ),

@@ -114,8 +108,15 @@ hooks = [
     dict(type='Accuracy2DHook'),
     dict(type='LossHook'),
     dict(type='TensorboardHook', log_dir='./tfb_logs'),
+    dict(
+        type='LRSchedulerHook',
+        by_epoch=True,
+        lr_scheduler_cfg=dict(
+            type='LinearWarmupLR',
+            warmup_steps=5
+        )
+    ),
     dict(type='SaveCheckpointHook', interval=5, checkpoint_dir='./ckpt'),
-    # dict(type='LoadCheckpointHook', epoch=20, checkpoint_dir='./ckpt')
 ]

 parallel = dict(

@@ -125,11 +126,8 @@ parallel = dict(

 fp16 = dict(mode=AMP_TYPE.PARALLEL, initial_scale=2 ** 8)

-lr_scheduler = dict(type='LinearWarmupLR', warmup_epochs=5)
-
-schedule = dict(num_microbatches=1)
-
-num_epochs = 60
-num_microbatches = 1
+engine = dict(
+    schedule=dict(num_microbatches=1)
+)

 logging = dict(root_path='./logs')
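This config also moves the pipeline schedule under a nested `engine` section, replacing the old flat `schedule` and `num_microbatches` keys. A hypothetical helper that reads the new layout with a fallback to the old one is sketched below; the function name and fallback behaviour are assumptions for illustration, not ColossalAI code.

```python
# Hypothetical accessor for the nested engine config introduced above,
# falling back to the older flat num_microbatches key if present.
def get_num_microbatches(config: dict) -> int:
    engine_cfg = config.get('engine', {})
    schedule_cfg = engine_cfg.get('schedule', {})
    return schedule_cfg.get('num_microbatches',
                            config.get('num_microbatches', 1))


new_style = dict(engine=dict(schedule=dict(num_microbatches=1)))
old_style = dict(num_microbatches=1)
assert get_num_microbatches(new_style) == get_num_microbatches(old_style) == 1
```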
@@ -1,25 +1,16 @@
 import colossalai
 from colossalai.core import global_context as gpc
-from colossalai.engine import Engine
 from colossalai.logging import get_global_dist_logger
 from colossalai.trainer import Trainer


 def test_trainer():
-    model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize()
+    engine, train_dataloader, test_dataloader = colossalai.initialize()
     logger = get_global_dist_logger()

-    engine = Engine(
-        model=model,
-        criterion=criterion,
-        optimizer=optimizer,
-        lr_scheduler=lr_scheduler,
-        schedule=schedule
-    )
     logger.info("engine is built", ranks=[0])

     trainer = Trainer(engine=engine,
-                      hooks_cfg=gpc.config.hooks,
                       verbose=True)
     logger.info("trainer is built", ranks=[0])

@@ -27,7 +18,8 @@ def test_trainer():
     trainer.fit(
         train_dataloader=train_dataloader,
         test_dataloader=test_dataloader,
-        max_epochs=gpc.config.num_epochs,
+        hooks_cfg=gpc.config.hooks,
+        epochs=gpc.config.num_epochs,
         display_progress=False,
         test_interval=5
     )
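This test also relies on rank-filtered logging (`logger.info("engine is built", ranks=[0])`), so only the listed ranks print a message. A single-process sketch of that behaviour is below; `SimpleDistLogger` is a hypothetical stand-in for the logger returned by `get_global_dist_logger()`.

```python
# Single-process sketch of rank-filtered logging as used in the test above.
import logging

logging.basicConfig(level=logging.INFO)


class SimpleDistLogger:
    """Hypothetical stand-in for the distributed logger used in the test."""

    def __init__(self, rank: int):
        self._rank = rank
        self._logger = logging.getLogger('colossalai-demo')

    def info(self, msg: str, ranks=None):
        # emit only when no rank filter is given or this process's rank matches
        if ranks is None or self._rank in ranks:
            self._logger.info('[rank %d] %s', self._rank, msg)


logger = SimpleDistLogger(rank=0)
logger.info("engine is built", ranks=[0])        # printed on rank 0
logger.info("only rank 1 sees this", ranks=[1])  # suppressed here
```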
@@ -18,14 +18,16 @@ level = os.environ['LEVEL']
 CONFIG_PATH = Path(__file__).parent.parent.joinpath(f'configs/vit_2d_zero{level}.py')


-def eval(engine):
+def eval_epoch(engine: Engine, test_dataloader):
     engine.eval()
     accumulated_loss = 0
     correct_sum = 0
     total_sum = 0
+    num_steps = len(test_dataloader)
+    data_iter = iter(test_dataloader)

-    for i in range(engine.schedule.num_steps):
-        output, label, loss = engine.step()
+    for i in range(num_steps):
+        output, label, loss = engine.step(data_iter)
         accumulated_loss += loss.detach().cpu().numpy()

         output = _gather(
@@ -42,18 +44,19 @@ def eval(engine):
         correct = torch.sum(label[0] == output)
         correct_sum += correct
         total_sum += label[0].size(0)
-    avg_loss = accumulated_loss / engine.schedule.num_steps
+    avg_loss = accumulated_loss / num_steps
     return correct_sum, total_sum, avg_loss


-def train(engine):
+def train_epoch(engine, train_dataloader):
     engine.train()
     accumulated_loss = 0
-
-    for i in range(engine.schedule.num_steps):
-        output, label, loss = engine.step()
+    num_steps = len(train_dataloader)
+    data_iter = iter(train_dataloader)
+    for i in range(num_steps):
+        output, label, loss = engine.step(data_iter)
         accumulated_loss += loss.detach().cpu().numpy()
-    avg_loss = accumulated_loss / engine.schedule.num_steps
+    avg_loss = accumulated_loss / num_steps
     return avg_loss


@@ -61,30 +64,17 @@ def train(engine):
 @pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus")
 def test_2d_parallel_vision_transformer():
     # init dist
-    model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize(
-        CONFIG_PATH)
+    engine, train_dataloader, test_dataloader = colossalai.initialize(CONFIG_PATH)
     logger = get_global_dist_logger()

-    engine = Engine(model=model,
-                    train_dataloader=train_dataloader,
-                    test_dataloader=test_dataloader,
-                    criterion=criterion,
-                    optimizer=optimizer,
-                    lr_scheduler=lr_scheduler,
-                    schedule=schedule)
-
-    # for param in model.parameters():
-    #     if isinstance(param, torch.HalfTensor):
-    #         print(param.shape)
-
     logger.info('start training')
     for epoch in range(gpc.config.num_epochs):
-        train_loss = train(engine)
+        train_loss = train_epoch(engine, train_dataloader)

         logger.info(f'epoch {epoch} - train loss: {train_loss}')

         if epoch % 2 == 0:
-            correct_sum, total_sum, eval_loss = eval(engine)
+            correct_sum, total_sum, eval_loss = eval_epoch(engine, test_dataloader)
             logger.info(
                 f'epoch {epoch} - eval loss: {eval_loss}, total: {total_sum}, '
                 f'correct: {correct_sum}, acc: {correct_sum / total_sum}')
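Finally, the ZeRO test's `eval_epoch` accumulates loss, correct predictions and sample counts over `len(test_dataloader)` batches. The plain-PyTorch sketch below reproduces that tally on toy data; the tensor-parallel `_gather` of the output is omitted, so it assumes the full logits are already local.

```python
# Plain-PyTorch sketch of the evaluation tally used in eval_epoch above.
import torch
from torch.utils.data import DataLoader, TensorDataset


def eval_epoch(model, criterion, test_dataloader):
    model.eval()
    accumulated_loss, correct_sum, total_sum = 0.0, 0, 0
    num_steps = len(test_dataloader)
    data_iter = iter(test_dataloader)
    with torch.no_grad():
        for _ in range(num_steps):
            data, label = next(data_iter)
            output = model(data)
            loss = criterion(output, label)
            accumulated_loss += loss.item()
            pred = torch.argmax(output, dim=-1)      # predicted class per sample
            correct_sum += torch.sum(label == pred).item()
            total_sum += label.size(0)
    return correct_sum, total_sum, accumulated_loss / num_steps


# usage with toy data
loader = DataLoader(
    TensorDataset(torch.randn(64, 8), torch.randint(0, 2, (64,))),
    batch_size=16)
print(eval_epoch(torch.nn.Linear(8, 2), torch.nn.CrossEntropyLoss(), loader))
```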