add pytorch hooks (#179)

* add pytorch hooks
fix #175

* remove licenses in src code

* add gpu memory tracer

* replacing print with logger in ophooks.
pull/200/head
Jiarui Fang 2022-01-25 22:20:54 +08:00 committed by GitHub
parent 708404d5f8
commit 569357fea0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 479 additions and 134 deletions

View File

@ -1,10 +1,12 @@
from .builder import (build_schedule, build_lr_scheduler, build_model, build_optimizer, build_layer,
build_loss, build_hooks, build_dataset, build_transform, build_data_sampler,
build_gradient_handler)
from .builder import (build_schedule, build_lr_scheduler, build_model,
build_optimizer, build_layer, build_loss, build_hooks,
build_dataset, build_transform, build_data_sampler,
build_gradient_handler, build_ophooks)
from .pipeline import build_pipeline_model, build_pipeline_model_from_cfg
__all__ = [
'build_schedule', 'build_lr_scheduler', 'build_model', 'build_optimizer',
'build_layer', 'build_loss', 'build_hooks', 'build_dataset', 'build_transform', 'build_data_sampler',
'build_gradient_handler', 'build_pipeline_model', 'build_pipeline_model_from_cfg'
'build_layer', 'build_loss', 'build_hooks', 'build_dataset',
'build_transform', 'build_data_sampler', 'build_gradient_handler',
'build_pipeline_model', 'build_pipeline_model_from_cfg', 'build_ophooks'
]

View File

@ -27,7 +27,7 @@ def build_from_registry(config, registry: Registry):
"""Returns an object constructed from `config`, the type of the object
is specified by `registry`.
:param config: A python dict or a :class:`colossalai.context.Config` object
:param config: A python dict or a :class:`colossalai.context.Config` object
containing information used in the construction of the return object
:type config: dict or :class:`colossalai.context.colossalai.context.Config`
:param registry: A registry specifying the type of the return object
@ -50,7 +50,8 @@ def build_from_registry(config, registry: Registry):
obj = registry.get_module(mod_type)(**config_)
except Exception as e:
print(
f'An error occurred when building {mod_type} from registry {registry.name}', flush=True)
f'An error occurred when building {mod_type} from registry {registry.name}',
flush=True)
raise e
return obj
@ -69,7 +70,7 @@ def build_layer(config):
def build_loss(config):
"""Returns a loss function object of :class:`torch.autograd.Function` constructed
"""Returns a loss function object of :class:`torch.autograd.Function` constructed
from `config`.
:param config: A python dict or a :class:`colossalai.context.Config` object
@ -94,7 +95,7 @@ def build_model(config):
def build_dataset(config):
"""Returns a dataset object of :class:`torch.utils.data.Dataset` constructed
"""Returns a dataset object of :class:`torch.utils.data.Dataset` constructed
from `config`.
:param config: A python dict or a :class:`colossalai.context.Config` object
@ -107,13 +108,13 @@ def build_dataset(config):
def build_optimizer(config, model):
"""Returns an optimizer object of :class:`torch.optim.Optimizer` constructed from `config`,
"""Returns an optimizer object of :class:`torch.optim.Optimizer` constructed from `config`,
'model' and 'params'.
:param config: A python dict or a :class:`colossalai.context.Config` object
:param config: A python dict or a :class:`colossalai.context.Config` object
containing information used in the construction of the return object
:type config: dict or :class:`colossalai.context.Config`
:param model: A model containing parameters for the optimizer
:param model: A model containing parameters for the optimizer
:type model: :class:`nn.Module`
:return: An object of :class:`torch.optim.Optimizer`
:rtype: :class:`torch.optim.Optimizer`
@ -159,6 +160,19 @@ def build_hooks(config, trainer):
return build_from_registry(config_, HOOKS)
def build_ophooks(config):
    """Returns an operator hook object of :class:`BaseOpHook` constructed from `config`.

    :param config: A python dict or a :class:`colossalai.context.Config` object
        containing information used in the construction of the return object
    :type config: dict or :class:`colossalai.context.Config`
    :return: An object of :class:`colossalai.engine.ophooks.BaseOpHook`
    :rtype: :class:`colossalai.engine.ophooks.BaseOpHook`
    """
    # Hand a shallow copy to the registry builder rather than the caller's object.
    return build_from_registry(config.copy(), OPHOOKS)
def build_transform(config):
"""Returns a transformation object of :class:`torchvision.transforms` constructed
from `config`.
@ -191,10 +205,10 @@ def build_data_sampler(config, dataset):
def build_lr_scheduler(config, optimizer):
"""Returns a learning rate scheduler object of :class:`torch.optim.lr_scheduler`
"""Returns a learning rate scheduler object of :class:`torch.optim.lr_scheduler`
constructed from `config`, `optimizer`, `total_steps` and `num_steps_per_epoch`.
:param config: A python dict or a :class:`colossalai.context.Config` object
:param config: A python dict or a :class:`colossalai.context.Config` object
containing information used in the construction of the return object
:type config: dict or :class:`colossalai.context.Config`
:param optimizer: An optimizer object containing parameters for the learning rate

View File

@ -1,7 +1,6 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import List
from torch.nn import Module
from torch.nn.modules.loss import _Loss
@ -9,10 +8,11 @@ from torch.optim import Optimizer
from colossalai.logging import get_dist_logger
from torch import Tensor
from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook
class Engine:
"""Basic engine class for training and evaluation. It runs a specific process method
"""Basic engine class for training and evaluation. It runs a specific process method
:meth:`step` which is based on the given :attr:`schedule` over each batch of a dataset.
It controls a iteration in training.
@ -29,15 +29,14 @@ class Engine:
:param verbose: whether to display log info
:type verbose: bool
"""
def __init__(self,
model: Module,
optimizer: Optimizer,
criterion: _Loss,
gradient_handlers: List = None,
clip_grad_norm: float = 0.0,
verbose: bool = True
):
ophook_list: List[BaseOpHook] = [],
verbose: bool = True):
self._model = model
self._optimizer = optimizer
self._criterion = criterion
@ -54,6 +53,9 @@ class Engine:
else:
self._gradient_handlers = []
self._ophook_list = ophook_list
register_ophooks_recursively(self._model, self._ophook_list)
@property
def model(self):
"""Model attached to the engine"""
@ -87,7 +89,10 @@ class Engine:
:param loss: Loss value computed by a loss function
:type loss: :class:`torch.Tensor`
"""
return self.optimizer.backward(loss)
ret = self.optimizer.backward(loss)
for ophook in self._ophook_list:
ophook.post_iter()
return ret
def backward_by_grad(self, tensor, grad):
"""Start backward propagation given the gradient of the output tensor
@ -97,7 +102,10 @@ class Engine:
:param grad: Gradient passed back to the output
:type grad: :class:`torch.Tensor`
"""
return self.optimizer.backward_by_grad(tensor, grad)
ret = self.optimizer.backward_by_grad(tensor, grad)
for ophook in self._ophook_list:
ophook.post_iter()
return ret
def calc_loss(self, *args, **kwargs):
"""Compute the loss value

View File

@ -0,0 +1,115 @@
from ._base_ophook import BaseOpHook
from ._memtracer_ophook import MemTracerOpHook
import torch
from typing import List
# Public names of this package. The dunder spelling is required: a plain
# ``all`` only rebinds the builtin and is ignored by ``from ... import *``.
__all__ = ["BaseOpHook", "MemTracerOpHook", "register_ophooks_recursively"]
# apply torch.autograd.Function that calls a backward_function to tensors in output
def _apply_to_tensors_only(module, functional, backward_function, outputs):
    """Wrap every tensor found in ``outputs`` with ``functional``.

    Plain tuples are walked recursively and rebuilt as tuples. An object whose
    exact type is ``torch.Tensor`` is replaced by
    ``functional.apply(module, backward_function, tensor)``. Anything else
    (including subclasses of tuple or Tensor) is returned unchanged.
    """
    kind = type(outputs)
    if kind is torch.Tensor:
        return functional.apply(module, backward_function, outputs)
    if kind is tuple:
        return tuple(
            _apply_to_tensors_only(module, functional, backward_function, item)
            for item in outputs)
    return outputs
class PreBackwardFunction(torch.autograd.Function):
    """Identity-like autograd op that runs a callback when backward reaches it.

    ``forward`` returns a detached copy of ``outputs``; ``backward`` calls the
    stored callback with the stored module and passes the incoming gradients
    through unchanged.
    """

    @staticmethod
    def forward(ctx, module, pre_backward_function, outputs):
        # Keep what backward() needs on the autograd context.
        ctx.pre_backward_function = pre_backward_function
        ctx.module = module
        # Flag is (re)set on every forward pass through this op.
        module.applied_pre_backward = False
        return outputs.detach()

    @staticmethod
    def backward(ctx, *args):
        ctx.pre_backward_function(ctx.module)
        # ``module`` and the callback take no gradient; grads for ``outputs``
        # are forwarded as received.
        no_grad_slots = (None, None)
        return no_grad_slots + args
class PostBackwardFunction(torch.autograd.Function):
    """Identity-like autograd op whose backward triggers a stored callback.

    Structurally the same as ``PreBackwardFunction`` except that ``forward``
    does not touch any attribute of ``module``. The callback parameter keeps
    the name ``pre_backward_function`` for interface compatibility.
    """

    @staticmethod
    def forward(ctx, module, pre_backward_function, output):
        # Keep what backward() needs on the autograd context.
        ctx.pre_backward_function = pre_backward_function
        ctx.module = module
        return output.detach()

    @staticmethod
    def backward(ctx, *args):
        """
        Args:
            activation_grad of the next layer.
        Returns:
            grad of the input activation.
        """
        ctx.pre_backward_function(ctx.module)
        # ``module`` and the callback take no gradient; the activation grads
        # are forwarded as received.
        no_grad_slots = (None, None)
        return no_grad_slots + args
def register_ophooks_recursively(module: torch.nn.Module,
                                 ophook_list: List[BaseOpHook] = None,
                                 name: str = ""):
    r"""Recursively register pre/post hooks for the leaf submodules of ``module`` in FWD and BWD.

    Hooks are attached only to modules that have no children and that own at
    least one parameter or buffer themselves; every other module is only
    traversed.

    :param module: Root module to traverse
    :type module: :class:`torch.nn.Module`
    :param ophook_list: Hooks to dispatch to. The list object itself is captured
        by the registered callbacks. When ``None`` nothing is registered.
    :type ophook_list: List[:class:`BaseOpHook`], optional
    :param name: Prefix built from child names during recursion; it is only
        passed down and not otherwise used here
    :type name: str
    """
    assert isinstance(module, torch.nn.Module)

    # Without a hook list there is nothing to dispatch to. Registering the
    # callbacks anyway would make them iterate over ``None`` and raise a
    # TypeError on the first forward pass.
    if ophook_list is None:
        return

    has_children = False
    for child_name, child in module.named_children():
        register_ophooks_recursively(child, ophook_list, name + child_name)
        has_children = True

    # Skip modules that own no parameters or buffers of their own.
    owns_state = (any(True for _ in module.named_parameters(recurse=False))
                  or any(True for _ in module.named_buffers(recurse=False)))
    if not owns_state:
        return

    # Only leaf modules get hooks; containers were handled through their children.
    if has_children:
        return

    for hook in ophook_list:
        assert isinstance(hook, BaseOpHook)

    def _pre_forward_module_hook(submodule, *args):
        assert isinstance(submodule, torch.nn.Module)
        for hook in ophook_list:
            hook.pre_fwd_exec(submodule, *args)

    def _post_forward_module_hook(submodule, *args):
        assert isinstance(submodule, torch.nn.Module)
        for hook in ophook_list:
            hook.post_fwd_exec(submodule, *args)

    def _pre_backward_module_hook(submodule, inputs, output):
        # Runs as a forward hook: wraps the output tensors so that the
        # callback below fires when backward reaches them.
        def _run_before_backward_function(submodule):
            assert isinstance(submodule, torch.nn.Module)
            for hook in ophook_list:
                hook.pre_bwd_exec(submodule, inputs, output)

        return _apply_to_tensors_only(submodule, PreBackwardFunction,
                                      _run_before_backward_function, output)

    def _post_backward_module_hook(submodule, inputs):
        # Runs as a forward pre-hook: wraps the input tensors so that the
        # callback below fires when backward reaches them.
        def _run_after_backward_function(submodule):
            assert isinstance(submodule, torch.nn.Module)
            for hook in ophook_list:
                hook.post_bwd_exec(submodule, inputs)

        return _apply_to_tensors_only(submodule, PostBackwardFunction,
                                      _run_after_backward_function, inputs)

    module.register_forward_pre_hook(_pre_forward_module_hook)
    module.register_forward_hook(_post_forward_module_hook)

    module.register_forward_hook(_pre_backward_module_hook)
    module.register_forward_pre_hook(_post_backward_module_hook)

View File

@ -0,0 +1,29 @@
from abc import ABC, abstractmethod
import torch
class BaseOpHook(ABC):
    """Abstract interface for user-defined operations that run before and
    after the execution of a PyTorch submodule.

    Concrete hooks must implement all five callbacks below.
    """

    def __init__(self):
        """No state is set up by the base class."""

    @abstractmethod
    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        """Callback for the point before ``module`` runs its forward pass."""

    @abstractmethod
    def post_fwd_exec(self, module: torch.nn.Module, *args):
        """Callback for the point after ``module`` ran its forward pass."""

    @abstractmethod
    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        """Callback for the point before ``module`` runs its backward pass."""

    @abstractmethod
    def post_bwd_exec(self, module: torch.nn.Module, input):
        """Callback for the point after ``module`` ran its backward pass."""

    @abstractmethod
    def post_iter(self):
        """Callback for the end of an iteration."""

View File

@ -0,0 +1,131 @@
import torch
from . import BaseOpHook
from concurrent.futures import ThreadPoolExecutor
from colossalai.registry import OPHOOKS
from colossalai.logging import get_dist_logger
from time import sleep, time
import psutil
import pickle
def get_cuda_memory_used(device):
    """Return the memory currently allocated on a CUDA device.

    :param device: CUDA device to query; passed straight to
        :func:`torch.cuda.memory_allocated`
    :type device: :class:`torch.device` or int
    :return: Value reported by :func:`torch.cuda.memory_allocated` for `device`
    :rtype: int
    """
    # Query the requested device instead of silently using the current one.
    ret = torch.cuda.memory_allocated(device)
    # Side effect: the peak-statistics counter of the device is reset on every
    # call (when the running PyTorch provides the API).
    if hasattr(torch.cuda, "reset_peak_memory_stats"):  # pytorch 1.4+
        torch.cuda.reset_peak_memory_stats(device)
    return ret
class AsyncMemoryMonitor:
    """Background sampler of the CUDA memory usage of the current device.

    While measuring, a single worker thread polls ``get_cuda_memory_used``
    every ``interval`` seconds, where ``interval = 1 / (10**power)``.
    ``finish`` stops the polling and records the largest sample together with
    a timestamp.
    """

    def __init__(self, power=10):
        # Recorded history: one entry per completed start()/finish() pair.
        self.time_stamps = []
        self.mem_stats = []
        self.interval = 1 / (10**power)
        # Future of the running sampling task, or None when idle.
        self.monitor_thread = None
        self.executor = ThreadPoolExecutor(max_workers=1)
        self.keep_measuring = False

    def set_interval(self, power: int):
        """Set the sampling interval to ``1 / (10**power)`` seconds."""
        self.interval = 1 / (10**power)

    def is_measuring(self):
        """Return the current value of the ``keep_measuring`` flag."""
        return self.keep_measuring

    def start(self):
        """Raise the measuring flag and submit the sampling loop to the executor."""
        self.keep_measuring = True
        self.monitor_thread = self.executor.submit(self._measure_usage)

    def finish(self):
        """Stop sampling and return the peak usage seen; 0 when not measuring."""
        if self.keep_measuring is False:
            return 0
        # Lowering the flag lets the sampling loop exit; result() waits for it.
        self.keep_measuring = False
        peak = self.monitor_thread.result()
        self.monitor_thread = None
        self.mem_stats.append(peak)
        self.time_stamps.append(time())
        return peak

    def _measure_usage(self):
        """Poll memory usage until the measuring flag is lowered; return the maximum."""
        peak = 0
        device = torch.device(f"cuda:{torch.cuda.current_device()}")
        while self.keep_measuring:
            sample = get_cuda_memory_used(device)
            if sample > peak:
                peak = sample
            sleep(self.interval)
        return peak

    def state_dict(self):
        """Return the recorded timestamps and peak values as a dict."""
        return dict(time_stamps=self.time_stamps, mem_stats=self.mem_stats)

    def save(self, filename):
        """Pickle ``state_dict()`` into ``filename``."""
        with open(filename, "wb") as sink:
            pickle.dump(self.state_dict(), sink)
@OPHOOKS.register_module
class MemTracerOpHook(BaseOpHook):
    """Operator hook that samples memory usage around module execution.

    Sampling is active only while the module is in training mode and fewer
    than ``niter`` iterations have been counted by ``post_iter``. The
    statistics are pickled to ``memstats.pkl`` when the iteration counter
    equals ``niter``.
    """

    def __init__(self, niter=5):
        super().__init__()
        self._niter = niter
        self._curiter = 0
        self.async_mem_monitor = AsyncMemoryMonitor()
        self._logger = get_dist_logger()

    def _isvalid(self, module):
        """Tell whether sampling should happen for ``module`` right now."""
        return module.training and self._curiter < self._niter

    def niter(self):
        """Return the configured number of iterations to trace."""
        return self._niter

    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        if not self._isvalid(module):
            return
        # Close any sample still running, then open a new one for this op.
        monitor = self.async_mem_monitor
        monitor.finish()
        monitor.start()
        self._logger.debug(f'FWD PRE {module.__class__.__name__}')

    def post_fwd_exec(self, module: torch.nn.Module, *args):
        if not self._isvalid(module):
            return
        self.async_mem_monitor.finish()
        self._logger.debug(f'FWD POST {module.__class__.__name__}')

    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        assert isinstance(module, torch.nn.Module)
        if not self._isvalid(module):
            return
        # Close any sample still running, then open a new one for this op.
        monitor = self.async_mem_monitor
        monitor.finish()
        monitor.start()
        self._logger.debug(f'BWD PRE {module.__class__.__name__}')

    def post_bwd_exec(self, module: torch.nn.Module, input):
        assert isinstance(module, torch.nn.Module)
        if not self._isvalid(module):
            return
        self.async_mem_monitor.finish()
        self._logger.debug(f'BWD POST {module.__class__.__name__}')

    def pre_iter(self):
        """Nothing to do before an iteration."""

    def post_iter(self):
        """Stop sampling, dump the statistics once the counter equals ``niter``,
        then advance the iteration counter."""
        self.async_mem_monitor.finish()
        dump_now = self._curiter == self._niter
        if dump_now:
            self._logger.info(
                'dump a memory statistics as pickle to ./memstats.pkl')
            self.save_results("memstats.pkl")
        self._curiter += 1

    def save_results(self, filename):
        """Write the monitor's recorded statistics to ``filename``."""
        self.async_mem_monitor.save(filename)

View File

@ -7,16 +7,17 @@ from torchvision import transforms
from .registry import Registry
LAYERS = Registry('layers', third_party_library=[nn])
LOSSES = Registry('losses')
MODELS = Registry('models', third_party_library=[tv_models])
OPTIMIZERS = Registry('optimizers', third_party_library=[optim, dist_optim])
DATASETS = Registry('datasets', third_party_library=[tv_datasets])
DIST_GROUP_INITIALIZER = Registry('dist_group_initializer')
GRADIENT_HANDLER = Registry('gradient_handler')
LOSSES = Registry('losses', third_party_library=[nn])
HOOKS = Registry('hooks')
TRANSFORMS = Registry('transforms', third_party_library=[transforms])
DATA_SAMPLERS = Registry('data_samplers')
LR_SCHEDULERS = Registry('lr_schedulers')
SCHEDULE = Registry('schedules')
LAYERS = Registry("layers", third_party_library=[nn])
LOSSES = Registry("losses")
MODELS = Registry("models", third_party_library=[tv_models])
OPTIMIZERS = Registry("optimizers", third_party_library=[optim, dist_optim])
DATASETS = Registry("datasets", third_party_library=[tv_datasets])
DIST_GROUP_INITIALIZER = Registry("dist_group_initializer")
GRADIENT_HANDLER = Registry("gradient_handler")
LOSSES = Registry("losses", third_party_library=[nn])
HOOKS = Registry("hooks")
TRANSFORMS = Registry("transforms", third_party_library=[transforms])
DATA_SAMPLERS = Registry("data_samplers")
LR_SCHEDULERS = Registry("lr_schedulers")
SCHEDULE = Registry("schedules")
OPHOOKS = Registry("ophooks")

View File

@ -1,8 +1,4 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Union, List
from colossalai import engine
from colossalai.context.parallel_mode import ParallelMode
import torch
@ -11,17 +7,18 @@ from torch.utils.data import DataLoader
from tqdm import tqdm
from colossalai.core import global_context as gpc
from colossalai.engine import Engine
from colossalai.engine.schedule import NonPipelineSchedule, BaseSchedule
from colossalai.logging import DistributedLogger
from colossalai.utils import MultiTimer
from colossalai.utils import is_dp_rank_0, is_tp_rank_0, is_no_pp_or_last_stage
from .hooks import BaseHook
from colossalai.trainer.hooks import BaseHook
class Trainer:
"""This a class tending for easy deployments of users' training and evaluation instead of
writing their own scripts. It is similar with ``ignite.engine`` and ``keras.engine``, but is
"""This a class tending for easy deployments of users' training and evaluation instead of
writing their own scripts. It is similar with ``ignite.engine`` and ``keras.engine``, but is
called `Trainer`.
:param engine: Engine responsible for the process function
@ -33,12 +30,13 @@ class Trainer:
:param logger: Logger used to record the whole training
:type logger: :class:`colossalai.logging.DistributedLogger`, optional
"""
def __init__(self,
engine: Engine,
schedule: BaseSchedule = None,
timer: MultiTimer = None,
logger: DistributedLogger = None):
def __init__(
self,
engine: Engine,
schedule: BaseSchedule = None,
timer: MultiTimer = None,
logger: DistributedLogger = None,
):
# training-related params
self._engine = engine
self._max_epochs = 0
@ -63,29 +61,28 @@ class Trainer:
# set schedule which specifies the training iteration for the engine
if schedule is None:
schedule = NonPipelineSchedule()
if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
assert not isinstance(schedule, NonPipelineSchedule), \
'NonPipelineSchedule cannot be used for pipeline parallel training, please use PipelineSchedule instead.'
if (gpc.is_initialized(ParallelMode.PIPELINE)
and gpc.get_world_size(ParallelMode.PIPELINE) > 1):
assert not isinstance(
schedule, NonPipelineSchedule
), "NonPipelineSchedule cannot be used for pipeline parallel training, please use PipelineSchedule instead."
self._schedule = schedule
self._schedule.pre_processing(engine)
@property
def cur_epoch(self):
"""Returns the index of the current epoch.
"""
"""Returns the index of the current epoch."""
return self._cur_epoch
@cur_epoch.setter
def cur_epoch(self, epoch: int):
"""Set how many epochs have been processed.
"""
"""Set how many epochs have been processed."""
# allow setter for training resumption
self._cur_epoch = epoch
@property
def cur_step(self):
"""Returns how many iteration steps have been processed.
"""
"""Returns how many iteration steps have been processed."""
return self._cur_step
@property
@ -131,8 +128,7 @@ class Trainer:
getattr(self._timer, action)(item, *args, **kwargs)
def _reset_states(self) -> None:
"""Clear trainer states
"""
"""Clear trainer states"""
self.states = dict()
def _call_hooks(self, func, output=None):
@ -152,99 +148,122 @@ class Trainer:
@staticmethod
def _should_display_progress(display_progress: bool):
""" Only display progress on DP rank 0, TP rank 0 and PP last rank
"""
return display_progress and is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage()
"""Only display progress on DP rank 0, TP rank 0 and PP last rank"""
return (display_progress and is_dp_rank_0() and is_tp_rank_0()
and is_no_pp_or_last_stage())
def _train_epoch(self,
train_dataloader: DataLoader,
epoch: int = None,
display_progress: bool = False,
return_output_label: bool = True):
def _train_epoch(
self,
train_dataloader: DataLoader,
epoch: int = None,
display_progress: bool = False,
return_output_label: bool = True,
):
# set training state
self._engine.train()
data_iter = iter(train_dataloader)
progress = range(self._steps_per_epoch)
if display_progress:
if epoch is None:
progress = tqdm(progress, desc='[Train]')
progress = tqdm(progress, desc="[Train]")
else:
progress = tqdm(progress, desc=f'[Epoch {epoch} / Train]')
progress = tqdm(progress, desc=f"[Epoch {epoch} / Train]")
self._call_hooks('before_train_epoch')
self._call_timer(action='start', item='Train-epoch')
self._call_hooks("before_train_epoch")
self._call_timer(action="start", item="Train-epoch")
for i in progress:
self._call_hooks('before_train_iter')
self._call_timer(action='start', item='Train-step')
self._call_hooks("before_train_iter")
self._call_timer(action="start", item="Train-step")
# run 1 training step
self.engine.zero_grad()
logits, label, loss = self.schedule.forward_backward_step(
self.engine, data_iter, forward_only=False, return_loss=True, return_output_label=return_output_label)
self.engine,
data_iter,
forward_only=False,
return_loss=True,
return_output_label=return_output_label,
)
self.engine.step()
self._call_timer(action='stop', item='Train-step', keep_in_history=True)
self._call_hooks('after_train_iter', output=(logits, label, loss))
self._call_timer(action="stop",
item="Train-step",
keep_in_history=True)
self._call_hooks("after_train_iter", output=(logits, label, loss))
self._cur_step += 1
if display_progress:
if 'step_metrics' in self.states:
progress.set_postfix(**self.states['step_metrics'])
if "step_metrics" in self.states:
progress.set_postfix(**self.states["step_metrics"])
# stop when max iter is reached
if self._exceed_max_step():
break
self._call_timer(action='stop', item='Train-epoch', keep_in_history=True)
self._call_hooks('after_train_epoch')
self._call_timer(action='reset', item='Train-epoch')
self._call_timer(action="stop",
item="Train-epoch",
keep_in_history=True)
self._call_hooks("after_train_epoch")
self._call_timer(action="reset", item="Train-epoch")
def _eval(self,
test_dataloader: DataLoader,
epoch: int = None,
display_progress: bool = False,
return_output_label: bool = True):
def _eval(
self,
test_dataloader: DataLoader,
epoch: int = None,
display_progress: bool = False,
return_output_label: bool = True,
):
# switch engine status
self._engine.eval()
data_iter = iter(test_dataloader)
num_steps = len(test_dataloader)
self._call_hooks('before_test')
self._call_hooks("before_test")
# prepare progress bar
progress = range(num_steps)
if display_progress:
desc = 'Evaluation'
desc = "Evaluation"
if epoch is not None:
desc = '[Epoch %d / Test]' % epoch
desc = "[Epoch %d / Test]" % epoch
progress = tqdm(progress, desc=desc)
self._call_hooks('before_test_epoch')
self._call_timer(action='start', item='Test-epoch')
self._call_hooks("before_test_epoch")
self._call_timer(action="start", item="Test-epoch")
with torch.no_grad():
for _ in progress:
self._call_hooks('before_test_iter')
self._call_timer(action='start', item='Test-step')
self._call_hooks("before_test_iter")
self._call_timer(action="start", item="Test-step")
logits, label, loss = self.schedule.forward_backward_step(
self.engine, data_iter, forward_only=True, return_loss=True, return_output_label=return_output_label)
self._call_timer(action='stop', item='Test-step', keep_in_history=True)
self._call_hooks('after_test_iter',
self.engine,
data_iter,
forward_only=True,
return_loss=True,
return_output_label=return_output_label,
)
self._call_timer(action="stop",
item="Test-step",
keep_in_history=True)
self._call_hooks("after_test_iter",
output=(logits, label, loss))
if display_progress:
if 'step_metrics' in self.states:
progress.set_postfix(**self.states['step_metrics'])
if "step_metrics" in self.states:
progress.set_postfix(**self.states["step_metrics"])
self._call_timer(action='stop', item='Test-epoch', keep_in_history=True)
self._call_hooks('after_test_epoch')
self._call_hooks('after_test')
self._call_timer(action='reset', item='Test-step')
self._call_timer(action='reset', item='Test-epoch')
self._call_timer(action="stop",
item="Test-epoch",
keep_in_history=True)
self._call_hooks("after_test_epoch")
self._call_hooks("after_test")
self._call_timer(action="reset", item="Test-step")
self._call_timer(action="reset", item="Test-epoch")
def _exceed_max_step(self):
return self._max_steps is not None and self._cur_step >= self._max_steps
def fit(self,
def fit(
self,
train_dataloader: DataLoader,
epochs: int,
max_steps: int = None,
@ -253,7 +272,7 @@ class Trainer:
hooks: List[BaseHook] = None,
display_progress: bool = False,
return_output_label: bool = True,
):
):
"""Trains the model to fit training data.
:param train_dataloader: DataLoader in training
@ -290,7 +309,9 @@ class Trainer:
# reset hooks
self._reset_states()
if hooks is not None:
assert isinstance(hooks, list), f'expected argument hooks be to list, but got {type(hooks)}'
assert isinstance(
hooks, list
), f"expected argument hooks be to list, but got {type(hooks)}"
else:
hooks = []
self.hooks = hooks
@ -298,13 +319,16 @@ class Trainer:
if self._verbose:
for hook in self.hooks:
self._logger.info(
f'Using {hook.__class__.__name__} for training, priority = {hook.priority}', ranks=[0])
self._logger.info("Lower value means higher priority for calling hook function", ranks=[0])
self._call_hooks('after_hook_is_attached')
f"Using {hook.__class__.__name__} for training, priority = {hook.priority}",
ranks=[0],
)
self._logger.info(
"Lower value means higher priority for calling hook function",
ranks=[0])
self._call_hooks("after_hook_is_attached")
# start train
self._engine.train()
self._call_hooks('before_train')
self._call_hooks("before_train")
# recover step value if resuming training
last_epoch = self._cur_epoch
@ -317,16 +341,17 @@ class Trainer:
train_dataloader=train_dataloader,
epoch=epoch,
display_progress=display_progress,
return_output_label=return_output_label
return_output_label=return_output_label,
)
# start eval
if should_test and epoch % test_interval == 0:
self._eval(test_dataloader=test_dataloader,
display_progress=display_progress,
epoch=epoch,
return_output_label=return_output_label
)
self._eval(
test_dataloader=test_dataloader,
display_progress=display_progress,
epoch=epoch,
return_output_label=return_output_label,
)
self._cur_epoch += 1
@ -334,16 +359,19 @@ class Trainer:
if self._exceed_max_step():
self._logger.info(
f"Max number of steps {max_steps} has been reached, training is stopped automatically",
ranks=[0])
ranks=[0],
)
break
self._call_hooks('after_train')
self._call_timer('reset', 'Train-epoch')
self._call_hooks("after_train")
self._call_timer("reset", "Train-epoch")
def evaluate(self,
test_dataloader: DataLoader,
hooks: List[BaseHook] = None,
display_progress: bool = False,
return_output_label: bool = True):
def evaluate(
self,
test_dataloader: DataLoader,
hooks: List[BaseHook] = None,
display_progress: bool = False,
return_output_label: bool = True,
):
"""Evaluates the model with testing data.
:param test_dataloader: DataLoader in testing
@ -362,7 +390,9 @@ class Trainer:
# reset hooks
self._reset_states()
if hooks is not None:
assert isinstance(hooks, list), f'expected argument hooks be to list, but got {type(hooks)}'
assert isinstance(
hooks, list
), f"expected argument hooks be to list, but got {type(hooks)}"
else:
hooks = []
self.hooks = hooks
@ -370,15 +400,20 @@ class Trainer:
if self._verbose:
for hook in self.hooks:
self._logger.info(
f'Using {hook.__class__.__name__} for training, priority = {hook.priority}', ranks=[0])
self._logger.info("Lower value means higher priority for calling hook function", ranks=[0])
self._call_hooks('after_hook_is_attached')
f"Using {hook.__class__.__name__} for training, priority = {hook.priority}",
ranks=[0],
)
self._logger.info(
"Lower value means higher priority for calling hook function",
ranks=[0])
self._call_hooks("after_hook_is_attached")
# eval
self._eval(test_dataloader=test_dataloader,
display_progress=display_progress,
return_output_label=return_output_label
)
self._eval(
test_dataloader=test_dataloader,
display_progress=display_progress,
return_output_label=return_output_label,
)
def predict(self, data: Union[Tensor, List[Tensor]]):
"""Uses trained model to make a prediction for a tensor or a tensor list.
@ -399,6 +434,8 @@ class Trainer:
# for compatibility with schedule
simple_dataloader = [(data, None)]
data_iter = iter(simple_dataloader)
output, _, _ = self.schedule.forward_backward_step(
self.engine, data_iter, forward_only=True, return_loss=False)
output, _, _ = self.schedule.forward_backward_step(self.engine,
data_iter,
forward_only=True,
return_loss=False)
return output

View File

@ -6,6 +6,7 @@ from pathlib import Path
import pytest
from colossalai.context.config import Config
from colossalai.builder import build_ophooks
@pytest.mark.cpu
@ -17,3 +18,10 @@ def test_load_config():
assert config.train_data.dataset, 'cannot access grandchild attribute'
assert isinstance(config.train_data.dataset.transform_pipeline[0], dict), \
f'expected attribute transform_pipeline elements to be a dict, but found {type(config.train_data.dataset.transform_pipeline)}'
@pytest.mark.cpu
def test_load_ophooks():
    """A hook built from a ``MemTracerOpHook`` config reports the configured ``niter``."""
    ophook_cfg = {'type': 'MemTracerOpHook', 'niter': 2}
    built_hook = build_ophooks(ophook_cfg)
    assert built_hook.niter() == 2