add pytorch hooks (#179)

* add pytorch hooks
fix #175

* remove licenses in src code

* add gpu memory tracer

* replace print with logger in ophooks
Jiarui Fang 2022-01-25 22:20:54 +08:00 committed by GitHub
parent 708404d5f8
commit 569357fea0
9 changed files with 479 additions and 134 deletions

@@ -1,10 +1,12 @@
from .builder import (build_schedule, build_lr_scheduler, build_model,
                      build_optimizer, build_layer, build_loss, build_hooks,
                      build_dataset, build_transform, build_data_sampler,
                      build_gradient_handler, build_ophooks)
from .pipeline import build_pipeline_model, build_pipeline_model_from_cfg

__all__ = [
    'build_schedule', 'build_lr_scheduler', 'build_model', 'build_optimizer',
    'build_layer', 'build_loss', 'build_hooks', 'build_dataset',
    'build_transform', 'build_data_sampler', 'build_gradient_handler',
    'build_pipeline_model', 'build_pipeline_model_from_cfg', 'build_ophooks'
]

@@ -50,7 +50,8 @@ def build_from_registry(config, registry: Registry):
        obj = registry.get_module(mod_type)(**config_)
    except Exception as e:
        print(
            f'An error occurred when building {mod_type} from registry {registry.name}',
            flush=True)
        raise e
    return obj

@@ -159,6 +160,19 @@ def build_hooks(config, trainer):
    return build_from_registry(config_, HOOKS)


def build_ophooks(config):
    """Returns a hook object of :class:`BaseOpHook` constructed from `config`.

    :param config: A python dict or a :class:`colossalai.context.Config` object
        containing information used in the construction of the return object
    :type config: dict or :class:`colossalai.context.Config`
    :return: An object of :class:`colossalai.engine.ophooks.BaseOpHook`
    :rtype: :class:`colossalai.engine.ophooks.BaseOpHook`
    """
    config_ = config.copy()
    return build_from_registry(config_, OPHOOKS)


def build_transform(config):
    """Returns a transformation object of :class:`torchvision.transforms` constructed
    from `config`.
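
The new builder follows the same config-driven pattern as the other build_* helpers. A minimal usage sketch (it assumes Colossal-AI is installed and relies on the MemTracerOpHook registered later in this commit; it mirrors the test added at the end of this diff):

from colossalai.builder import build_ophooks

ophook_cfg = {'type': 'MemTracerOpHook', 'niter': 2}
mem_tracer = build_ophooks(ophook_cfg)   # looks up 'MemTracerOpHook' in the OPHOOKS registry
assert mem_tracer.niter() == 2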

@@ -1,7 +1,6 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from typing import List
from torch.nn import Module
from torch.nn.modules.loss import _Loss

@@ -9,6 +8,7 @@ from torch.optim import Optimizer
from colossalai.logging import get_dist_logger
from torch import Tensor
from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook


class Engine:

@@ -29,15 +29,14 @@ class Engine:
    :param verbose: whether to display log info
    :type verbose: bool
    """
    def __init__(self,
                 model: Module,
                 optimizer: Optimizer,
                 criterion: _Loss,
                 gradient_handlers: List = None,
                 clip_grad_norm: float = 0.0,
                 ophook_list: List[BaseOpHook] = [],
                 verbose: bool = True):
        self._model = model
        self._optimizer = optimizer
        self._criterion = criterion

@@ -54,6 +53,9 @@ class Engine:
        else:
            self._gradient_handlers = []

        self._ophook_list = ophook_list
        register_ophooks_recursively(self._model, self._ophook_list)

    @property
    def model(self):
        """Model attached to the engine"""

@@ -87,7 +89,10 @@ class Engine:
        :param loss: Loss value computed by a loss function
        :type loss: :class:`torch.Tensor`
        """
        ret = self.optimizer.backward(loss)
        for ophook in self._ophook_list:
            ophook.post_iter()
        return ret

    def backward_by_grad(self, tensor, grad):
        """Start backward propagation given the gradient of the output tensor

@@ -97,7 +102,10 @@ class Engine:
        :param grad: Gradient passed back to the output
        :type grad: :class:`torch.Tensor`
        """
        ret = self.optimizer.backward_by_grad(tensor, grad)
        for ophook in self._ophook_list:
            ophook.post_iter()
        return ret

    def calc_loss(self, *args, **kwargs):
        """Compute the loss value

@@ -0,0 +1,115 @@
from ._base_ophook import BaseOpHook
from ._memtracer_ophook import MemTracerOpHook
import torch
from typing import List

all = ["BaseOpHook", "MemTracerOpHook", "register_ophooks_recursively"]


# apply a torch.autograd.Function that calls a backward_function to tensors in output
def _apply_to_tensors_only(module, functional, backward_function, outputs):
    if type(outputs) is tuple:
        touched_outputs = []
        for output in outputs:
            touched_output = _apply_to_tensors_only(module, functional,
                                                    backward_function, output)
            touched_outputs.append(touched_output)
        return tuple(touched_outputs)
    elif type(outputs) is torch.Tensor:
        return functional.apply(module, backward_function, outputs)
    else:
        return outputs


class PreBackwardFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, module, pre_backward_function, outputs):
        ctx.module = module
        ctx.pre_backward_function = pre_backward_function
        module.applied_pre_backward = False
        outputs = outputs.detach()
        return outputs

    @staticmethod
    def backward(ctx, *args):
        ctx.pre_backward_function(ctx.module)
        return (None, None) + args


class PostBackwardFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, module, pre_backward_function, output):
        ctx.module = module
        output = output.detach()
        ctx.pre_backward_function = pre_backward_function
        return output

    @staticmethod
    def backward(ctx, *args):
        """
        Args:
            activation_grad of the next layer.
        Returns:
            grad of the input activation.
        """
        ctx.pre_backward_function(ctx.module)
        return (None, None) + args


def register_ophooks_recursively(module: torch.nn.Module,
                                 ophook_list: List[BaseOpHook] = None,
                                 name: str = ""):
    r"""Recursively register pre/post hooks for all submodules in the module in FWD and BWD."""
    assert isinstance(module, torch.nn.Module)
    has_children = False
    for child_name, child in module.named_children():
        register_ophooks_recursively(child, ophook_list, name + child_name)
        has_children = True

    # Early return on modules with no parameters or buffers that
    # are not in their children.
    if (len(list(module.named_parameters(recurse=False))) == 0
            and len(list(module.named_buffers(recurse=False))) == 0):
        return

    # return if the module has children.
    if has_children:
        return

    if ophook_list is not None:
        for hook in ophook_list:
            assert (isinstance(hook, BaseOpHook))

    def _pre_forward_module_hook(submodule, *args):
        for hook in ophook_list:
            assert isinstance(submodule, torch.nn.Module)
            hook.pre_fwd_exec(submodule, *args)

    def _post_forward_module_hook(submodule, *args):
        for hook in ophook_list:
            assert isinstance(submodule, torch.nn.Module)
            hook.post_fwd_exec(submodule, *args)

    def _pre_backward_module_hook(submodule, inputs, output):
        def _run_before_backward_function(submodule):
            for hook in ophook_list:
                assert isinstance(submodule, torch.nn.Module)
                hook.pre_bwd_exec(submodule, inputs, output)

        return _apply_to_tensors_only(submodule, PreBackwardFunction,
                                      _run_before_backward_function, output)

    def _post_backward_module_hook(submodule, inputs):
        def _run_after_backward_function(submodule):
            for hook in ophook_list:
                assert isinstance(submodule, torch.nn.Module)
                hook.post_bwd_exec(submodule, inputs)

        return _apply_to_tensors_only(submodule, PostBackwardFunction,
                                      _run_after_backward_function, inputs)

    module.register_forward_pre_hook(_pre_forward_module_hook)
    module.register_forward_hook(_post_forward_module_hook)
    module.register_forward_hook(_pre_backward_module_hook)
    module.register_forward_pre_hook(_post_backward_module_hook)
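
The registration helper can also be applied to a bare torch module, outside of the Engine. A minimal sketch, assuming a working Colossal-AI installation (the MemTracerOpHook used here is defined two files below):

import torch
from colossalai.engine.ophooks import MemTracerOpHook, register_ophooks_recursively

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 1))
register_ophooks_recursively(model, [MemTracerOpHook(niter=1)])
# hooks end up only on the two Linear leaves: the Sequential container and the
# ReLU own no parameters or buffers, so the early returns above skip them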

@@ -0,0 +1,29 @@
from abc import ABC, abstractmethod
import torch


class BaseOpHook(ABC):
    """This class allows users to add customized operations
    before and after the execution of a PyTorch submodule"""
    def __init__(self):
        pass

    @abstractmethod
    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        pass

    @abstractmethod
    def post_fwd_exec(self, module: torch.nn.Module, *args):
        pass

    @abstractmethod
    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        pass

    @abstractmethod
    def post_bwd_exec(self, module: torch.nn.Module, input):
        pass

    @abstractmethod
    def post_iter(self):
        pass
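
Any concrete hook has to implement all five abstract methods. A hypothetical subclass (not part of this commit) that just records the order in which leaf modules execute:

import torch
from colossalai.engine.ophooks import BaseOpHook


class TraceOpHook(BaseOpHook):
    """Hypothetical example hook: collects the execution order of leaf modules."""
    def __init__(self):
        super().__init__()
        self.trace = []

    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        self.trace.append(('FWD', module.__class__.__name__))

    def post_fwd_exec(self, module: torch.nn.Module, *args):
        pass

    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        self.trace.append(('BWD', module.__class__.__name__))

    def post_bwd_exec(self, module: torch.nn.Module, input):
        pass

    def post_iter(self):
        # called by Engine.backward(); reset the trace for the next iteration
        self.trace.clear()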

@@ -0,0 +1,131 @@
import torch
from . import BaseOpHook
from concurrent.futures import ThreadPoolExecutor
from colossalai.registry import OPHOOKS
from colossalai.logging import get_dist_logger
from time import sleep, time
import psutil
import pickle


def get_cuda_memory_used(device):
    """
    Get the memory currently allocated by PyTorch on the CUDA device.
    """
    ret = torch.cuda.memory_allocated()
    # get the peak memory to report correct data, so reset the counter for the next call
    if hasattr(torch.cuda, "reset_peak_memory_stats"):  # pytorch 1.4+
        torch.cuda.reset_peak_memory_stats()
    return ret


class AsyncMemoryMonitor:
    def __init__(self, power=10):
        """
        An async memory monitor running during computing.
        Samples GPU memory usage of the current GPU device
        at intervals of 1/(10**power) sec.
        """
        self.keep_measuring = False
        self.executor = ThreadPoolExecutor(max_workers=1)
        self.monitor_thread = None
        self.interval = 1 / (10**power)
        self.time_stamps = []
        self.mem_stats = []

    def set_interval(self, power: int):
        self.interval = 1 / (10**power)

    def is_measuring(self):
        return self.keep_measuring

    def start(self):
        self.keep_measuring = True
        self.monitor_thread = self.executor.submit(self._measure_usage)

    def finish(self):
        if self.keep_measuring is False:
            return 0
        self.keep_measuring = False
        max_usage = self.monitor_thread.result()
        self.monitor_thread = None
        self.time_stamps.append(time())
        self.mem_stats.append(max_usage)
        return max_usage

    def _measure_usage(self):
        max_usage = 0
        dev = torch.device(f"cuda:{torch.cuda.current_device()}")
        while self.keep_measuring:
            max_usage = max(
                max_usage,
                get_cuda_memory_used(dev),
            )
            sleep(self.interval)
        return max_usage

    def state_dict(self):
        return {
            "time_stamps": self.time_stamps,
            "mem_stats": self.mem_stats,
        }

    def save(self, filename):
        with open(filename, "wb") as f:
            pickle.dump(self.state_dict(), f)


@OPHOOKS.register_module
class MemTracerOpHook(BaseOpHook):
    def __init__(self, niter=5):
        super().__init__()
        self.async_mem_monitor = AsyncMemoryMonitor()
        self._niter = niter
        self._curiter = 0
        self._logger = get_dist_logger()

    def _isvalid(self, module):
        return module.training and self._curiter < self._niter

    def niter(self):
        return self._niter

    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        if self._isvalid(module):
            self.async_mem_monitor.finish()
            self.async_mem_monitor.start()
            self._logger.debug(f'FWD PRE {module.__class__.__name__}')

    def post_fwd_exec(self, module: torch.nn.Module, *args):
        if self._isvalid(module):
            self.async_mem_monitor.finish()
            self._logger.debug(f'FWD POST {module.__class__.__name__}')

    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        assert isinstance(module, torch.nn.Module)
        if self._isvalid(module):
            self.async_mem_monitor.finish()
            self.async_mem_monitor.start()
            self._logger.debug(f'BWD PRE {module.__class__.__name__}')

    def post_bwd_exec(self, module: torch.nn.Module, input):
        assert isinstance(module, torch.nn.Module)
        if self._isvalid(module):
            self.async_mem_monitor.finish()
            self._logger.debug(f'BWD POST {module.__class__.__name__}')

    def pre_iter(self):
        pass

    def post_iter(self):
        self.async_mem_monitor.finish()
        if self._curiter == self._niter:
            self._logger.info(
                'dump memory statistics as pickle to ./memstats.pkl')
            self.save_results("memstats.pkl")
        self._curiter += 1

    def save_results(self, filename):
        self.async_mem_monitor.save(filename)
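
The sampler can also be driven by hand, without going through the hook. A rough sketch, assuming a CUDA device is available and that the module path used below matches the package layout of this commit:

import torch
from colossalai.engine.ophooks._memtracer_ophook import AsyncMemoryMonitor

monitor = AsyncMemoryMonitor(power=3)        # sample roughly every 1e-3 s
monitor.start()
x = torch.randn(1024, 1024, device='cuda')
y = x @ x                                    # some CUDA work to measure
peak = monitor.finish()                      # largest torch-allocated byte count sampled
monitor.save('memstats.pkl')                 # pickles {'time_stamps': [...], 'mem_stats': [...]}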

@@ -7,16 +7,17 @@ from torchvision import transforms
from .registry import Registry

LAYERS = Registry("layers", third_party_library=[nn])
LOSSES = Registry("losses")
MODELS = Registry("models", third_party_library=[tv_models])
OPTIMIZERS = Registry("optimizers", third_party_library=[optim, dist_optim])
DATASETS = Registry("datasets", third_party_library=[tv_datasets])
DIST_GROUP_INITIALIZER = Registry("dist_group_initializer")
GRADIENT_HANDLER = Registry("gradient_handler")
LOSSES = Registry("losses", third_party_library=[nn])
HOOKS = Registry("hooks")
TRANSFORMS = Registry("transforms", third_party_library=[transforms])
DATA_SAMPLERS = Registry("data_samplers")
LR_SCHEDULERS = Registry("lr_schedulers")
SCHEDULE = Registry("schedules")
OPHOOKS = Registry("ophooks")

@@ -1,8 +1,4 @@
from typing import Union, List

from colossalai.context.parallel_mode import ParallelMode
import torch

@@ -11,12 +7,13 @@ from torch.utils.data import DataLoader
from tqdm import tqdm

from colossalai.core import global_context as gpc
from colossalai.engine import Engine
from colossalai.engine.schedule import NonPipelineSchedule, BaseSchedule
from colossalai.logging import DistributedLogger
from colossalai.utils import MultiTimer
from colossalai.utils import is_dp_rank_0, is_tp_rank_0, is_no_pp_or_last_stage
from colossalai.trainer.hooks import BaseHook


class Trainer:

@@ -33,12 +30,13 @@ class Trainer:
    :param logger: Logger used to record the whole training
    :type logger: :class:`colossalai.logging.DistributedLogger`, optional
    """
    def __init__(
        self,
        engine: Engine,
        schedule: BaseSchedule = None,
        timer: MultiTimer = None,
        logger: DistributedLogger = None,
    ):
        # training-related params
        self._engine = engine
        self._max_epochs = 0

@@ -63,29 +61,28 @@ class Trainer:
        # set schedule which specifies the training iteration for the engine
        if schedule is None:
            schedule = NonPipelineSchedule()
        if (gpc.is_initialized(ParallelMode.PIPELINE)
                and gpc.get_world_size(ParallelMode.PIPELINE) > 1):
            assert not isinstance(
                schedule, NonPipelineSchedule
            ), "NonPipelineSchedule cannot be used for pipeline parallel training, please use PipelineSchedule instead."
        self._schedule = schedule
        self._schedule.pre_processing(engine)

    @property
    def cur_epoch(self):
        """Returns the index of the current epoch."""
        return self._cur_epoch

    @cur_epoch.setter
    def cur_epoch(self, epoch: int):
        """Set how many epochs have been processed."""
        # allow setter for training resumption
        self._cur_epoch = epoch

    @property
    def cur_step(self):
        """Returns how many iteration steps have been processed."""
        return self._cur_step

    @property

@@ -131,8 +128,7 @@ class Trainer:
        getattr(self._timer, action)(item, *args, **kwargs)

    def _reset_states(self) -> None:
        """Clear trainer states"""
        self.states = dict()

    def _call_hooks(self, func, output=None):

@@ -152,99 +148,122 @@ class Trainer:
    @staticmethod
    def _should_display_progress(display_progress: bool):
        """Only display progress on DP rank 0, TP rank 0 and PP last rank"""
        return (display_progress and is_dp_rank_0() and is_tp_rank_0()
                and is_no_pp_or_last_stage())

    def _train_epoch(
        self,
        train_dataloader: DataLoader,
        epoch: int = None,
        display_progress: bool = False,
        return_output_label: bool = True,
    ):
        # set training state
        self._engine.train()
        data_iter = iter(train_dataloader)
        progress = range(self._steps_per_epoch)
        if display_progress:
            if epoch is None:
                progress = tqdm(progress, desc="[Train]")
            else:
                progress = tqdm(progress, desc=f"[Epoch {epoch} / Train]")

        self._call_hooks("before_train_epoch")
        self._call_timer(action="start", item="Train-epoch")
        for i in progress:
            self._call_hooks("before_train_iter")
            self._call_timer(action="start", item="Train-step")

            # run 1 training step
            self.engine.zero_grad()
            logits, label, loss = self.schedule.forward_backward_step(
                self.engine,
                data_iter,
                forward_only=False,
                return_loss=True,
                return_output_label=return_output_label,
            )
            self.engine.step()
            self._call_timer(action="stop",
                             item="Train-step",
                             keep_in_history=True)
            self._call_hooks("after_train_iter", output=(logits, label, loss))

            self._cur_step += 1

            if display_progress:
                if "step_metrics" in self.states:
                    progress.set_postfix(**self.states["step_metrics"])

            # stop when max iter is reached
            if self._exceed_max_step():
                break

        self._call_timer(action="stop",
                         item="Train-epoch",
                         keep_in_history=True)
        self._call_hooks("after_train_epoch")
        self._call_timer(action="reset", item="Train-epoch")

    def _eval(
        self,
        test_dataloader: DataLoader,
        epoch: int = None,
        display_progress: bool = False,
        return_output_label: bool = True,
    ):
        # switch engine status
        self._engine.eval()

        data_iter = iter(test_dataloader)
        num_steps = len(test_dataloader)

        self._call_hooks("before_test")
        # prepare progress bar
        progress = range(num_steps)
        if display_progress:
            desc = "Evaluation"
            if epoch is not None:
                desc = "[Epoch %d / Test]" % epoch
            progress = tqdm(progress, desc=desc)

        self._call_hooks("before_test_epoch")
        self._call_timer(action="start", item="Test-epoch")
        with torch.no_grad():
            for _ in progress:
                self._call_hooks("before_test_iter")
                self._call_timer(action="start", item="Test-step")
                logits, label, loss = self.schedule.forward_backward_step(
                    self.engine,
                    data_iter,
                    forward_only=True,
                    return_loss=True,
                    return_output_label=return_output_label,
                )
                self._call_timer(action="stop",
                                 item="Test-step",
                                 keep_in_history=True)
                self._call_hooks("after_test_iter",
                                 output=(logits, label, loss))

                if display_progress:
                    if "step_metrics" in self.states:
                        progress.set_postfix(**self.states["step_metrics"])

        self._call_timer(action="stop",
                         item="Test-epoch",
                         keep_in_history=True)
        self._call_hooks("after_test_epoch")
        self._call_hooks("after_test")
        self._call_timer(action="reset", item="Test-step")
        self._call_timer(action="reset", item="Test-epoch")

    def _exceed_max_step(self):
        return self._max_steps is not None and self._cur_step >= self._max_steps

    def fit(
        self,
        train_dataloader: DataLoader,
        epochs: int,
        max_steps: int = None,

@@ -290,7 +309,9 @@ class Trainer:
        # reset hooks
        self._reset_states()
        if hooks is not None:
            assert isinstance(
                hooks, list
            ), f"expected argument hooks to be a list, but got {type(hooks)}"
        else:
            hooks = []
        self.hooks = hooks

@@ -298,13 +319,16 @@ class Trainer:
        if self._verbose:
            for hook in self.hooks:
                self._logger.info(
                    f"Using {hook.__class__.__name__} for training, priority = {hook.priority}",
                    ranks=[0],
                )
            self._logger.info(
                "Lower value means higher priority for calling hook function",
                ranks=[0])
        self._call_hooks("after_hook_is_attached")

        self._engine.train()
        self._call_hooks("before_train")

        # recover step value if resuming training
        last_epoch = self._cur_epoch

@@ -317,15 +341,16 @@ class Trainer:
                train_dataloader=train_dataloader,
                epoch=epoch,
                display_progress=display_progress,
                return_output_label=return_output_label,
            )

            # start eval
            if should_test and epoch % test_interval == 0:
                self._eval(
                    test_dataloader=test_dataloader,
                    display_progress=display_progress,
                    epoch=epoch,
                    return_output_label=return_output_label,
                )

            self._cur_epoch += 1

@@ -334,16 +359,19 @@ class Trainer:
            if self._exceed_max_step():
                self._logger.info(
                    f"Max number of steps {max_steps} has been reached, training is stopped automatically",
                    ranks=[0],
                )
                break
        self._call_hooks("after_train")
        self._call_timer("reset", "Train-epoch")

    def evaluate(
        self,
        test_dataloader: DataLoader,
        hooks: List[BaseHook] = None,
        display_progress: bool = False,
        return_output_label: bool = True,
    ):
        """Evaluates the model with testing data.

        :param test_dataloader: DataLoader in testing

@@ -362,7 +390,9 @@ class Trainer:
        # reset hooks
        self._reset_states()
        if hooks is not None:
            assert isinstance(
                hooks, list
            ), f"expected argument hooks to be a list, but got {type(hooks)}"
        else:
            hooks = []
        self.hooks = hooks

@@ -370,14 +400,19 @@ class Trainer:
        if self._verbose:
            for hook in self.hooks:
                self._logger.info(
                    f"Using {hook.__class__.__name__} for training, priority = {hook.priority}",
                    ranks=[0],
                )
            self._logger.info(
                "Lower value means higher priority for calling hook function",
                ranks=[0])
        self._call_hooks("after_hook_is_attached")

        # eval
        self._eval(
            test_dataloader=test_dataloader,
            display_progress=display_progress,
            return_output_label=return_output_label,
        )

    def predict(self, data: Union[Tensor, List[Tensor]]):

@@ -399,6 +434,8 @@ class Trainer:
        # for compatibility with schedule
        simple_dataloader = [(data, None)]
        data_iter = iter(simple_dataloader)
        output, _, _ = self.schedule.forward_backward_step(self.engine,
                                                           data_iter,
                                                           forward_only=True,
                                                           return_loss=False)
        return output

@@ -6,6 +6,7 @@ from pathlib import Path
import pytest

from colossalai.context.config import Config
from colossalai.builder import build_ophooks


@pytest.mark.cpu

@@ -17,3 +18,10 @@ def test_load_config():
    assert config.train_data.dataset, 'cannot access grandchild attribute'
    assert isinstance(config.train_data.dataset.transform_pipeline[0], dict), \
        f'expected attribute transform_pipeline elements to be a dict, but found {type(config.train_data.dataset.transform_pipeline)}'


@pytest.mark.cpu
def test_load_ophooks():
    dict = {'type': 'MemTracerOpHook', 'niter': 2}
    ophook = build_ophooks(dict)
    assert ophook.niter() == 2