mirror of https://github.com/hpcaitech/ColossalAI
add pytorch hooks (#179)
* add pytorch hooks, fix #175
* remove licenses in src code
* add gpu memory tracer
* replace print with logger in ophooks

pull/200/head
parent 708404d5f8
commit 569357fea0
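In outline, this commit lets users attach operator-level hooks ("ophooks") that fire around every leaf submodule during forward and backward, plus a memory tracer built on them. A hedged sketch of the intended call order for one training iteration, using the method names of the `BaseOpHook` interface added below:

```python
# Call order per leaf submodule M in one iteration (names from this commit):
#   hook.pre_fwd_exec(M, *inputs)         # just before M's forward
#   hook.post_fwd_exec(M, *inputs)        # just after M's forward
#   hook.pre_bwd_exec(M, input, output)   # when backward reaches M's output
#   hook.post_bwd_exec(M, input)          # when backward leaves M's input
# and once per iteration, triggered from Engine.backward():
#   hook.post_iter()
```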
@@ -1,10 +1,12 @@
-from .builder import (build_schedule, build_lr_scheduler, build_model, build_optimizer, build_layer,
-                      build_loss, build_hooks, build_dataset, build_transform, build_data_sampler,
-                      build_gradient_handler)
+from .builder import (build_schedule, build_lr_scheduler, build_model,
+                      build_optimizer, build_layer, build_loss, build_hooks,
+                      build_dataset, build_transform, build_data_sampler,
+                      build_gradient_handler, build_ophooks)
 from .pipeline import build_pipeline_model, build_pipeline_model_from_cfg

 __all__ = [
     'build_schedule', 'build_lr_scheduler', 'build_model', 'build_optimizer',
-    'build_layer', 'build_loss', 'build_hooks', 'build_dataset', 'build_transform', 'build_data_sampler',
-    'build_gradient_handler', 'build_pipeline_model', 'build_pipeline_model_from_cfg'
+    'build_layer', 'build_loss', 'build_hooks', 'build_dataset',
+    'build_transform', 'build_data_sampler', 'build_gradient_handler',
+    'build_pipeline_model', 'build_pipeline_model_from_cfg', 'build_ophooks'
 ]
@@ -27,7 +27,7 @@ def build_from_registry(config, registry: Registry):
     """Returns an object constructed from `config`, the type of the object
     is specified by `registry`.

     :param config: A python dict or a :class:`colossalai.context.Config` object
         containing information used in the construction of the return object
     :type config: dict or :class:`colossalai.context.Config`
     :param registry: A registry specifying the type of the return object
@@ -50,7 +50,8 @@ def build_from_registry(config, registry: Registry):
         obj = registry.get_module(mod_type)(**config_)
     except Exception as e:
         print(
-            f'An error occurred when building {mod_type} from registry {registry.name}', flush=True)
+            f'An error occurred when building {mod_type} from registry {registry.name}',
+            flush=True)
         raise e

     return obj
@@ -69,7 +70,7 @@ def build_layer(config):


 def build_loss(config):
     """Returns a loss function object of :class:`torch.autograd.Function` constructed
     from `config`.

     :param config: A python dict or a :class:`colossalai.context.Config` object
@@ -94,7 +95,7 @@ def build_model(config):


 def build_dataset(config):
     """Returns a dataset object of :class:`torch.utils.data.Dataset` constructed
     from `config`.

     :param config: A python dict or a :class:`colossalai.context.Config` object
@@ -107,13 +108,13 @@ def build_dataset(config):


 def build_optimizer(config, model):
     """Returns an optimizer object of :class:`torch.optim.Optimizer` constructed from `config`,
     'model' and 'params'.

     :param config: A python dict or a :class:`colossalai.context.Config` object
         containing information used in the construction of the return object
     :type config: dict or :class:`colossalai.context.Config`
     :param model: A model containing parameters for the optimizer
     :type model: :class:`nn.Module`
     :return: An object of :class:`torch.optim.Optimizer`
     :rtype: :class:`torch.optim.Optimizer`
@@ -159,6 +160,19 @@ def build_hooks(config, trainer):
     return build_from_registry(config_, HOOKS)


+def build_ophooks(config):
+    """Returns a hook object of :class:`BaseOpHook` constructed from `config`.
+
+    :param config: A python dict or a :class:`colossalai.context.Config` object
+        containing information used in the construction of the return object
+    :type config: dict or :class:`colossalai.context.Config`
+    :return: An object of :class:`colossalai.trainer.hooks.BaseOpHook`
+    :rtype: :class:`colossalai.trainer.hooks.BaseOpHook`
+    """
+    config_ = config.copy()
+    return build_from_registry(config_, OPHOOKS)
+
+
 def build_transform(config):
     """Returns a transformation object of :class:`torchvision.transforms` constructed
     from `config`.
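For illustration, the new builder resolves a hook from a plain config dict, mirroring the unit test added at the bottom of this commit; the sketch assumes `MemTracerOpHook` has been registered under `OPHOOKS` (the registry change appears later in the diff):

```python
from colossalai.builder import build_ophooks

# 'type' selects the registered class from OPHOOKS; remaining keys are kwargs.
ophook = build_ophooks({'type': 'MemTracerOpHook', 'niter': 2})
assert ophook.niter() == 2
```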
@@ -191,10 +205,10 @@ def build_data_sampler(config, dataset):


 def build_lr_scheduler(config, optimizer):
     """Returns a learning rate scheduler object of :class:`torch.optim.lr_scheduler`
     constructed from `config`, `optimizer`, `total_steps` and `num_steps_per_epoch`.

     :param config: A python dict or a :class:`colossalai.context.Config` object
         containing information used in the construction of the return object
     :type config: dict or :class:`colossalai.context.Config`
     :param optimizer: An optimizer object containing parameters for the learning rate
@@ -1,7 +1,6 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

-
 from typing import List
 from torch.nn import Module
 from torch.nn.modules.loss import _Loss
@@ -9,10 +8,11 @@ from torch.optim import Optimizer

 from colossalai.logging import get_dist_logger
 from torch import Tensor
+from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook


 class Engine:
     """Basic engine class for training and evaluation. It runs a specific process method
     :meth:`step` which is based on the given :attr:`schedule` over each batch of a dataset.
     It controls an iteration in training.
@@ -29,15 +29,14 @@ class Engine:
     :param verbose: whether to display log info
     :type verbose: bool
     """

     def __init__(self,
                  model: Module,
                  optimizer: Optimizer,
                  criterion: _Loss,
                  gradient_handlers: List = None,
                  clip_grad_norm: float = 0.0,
-                 verbose: bool = True
-                 ):
+                 ophook_list: List[BaseOpHook] = [],
+                 verbose: bool = True):
         self._model = model
         self._optimizer = optimizer
         self._criterion = criterion
@@ -54,6 +53,9 @@ class Engine:
         else:
             self._gradient_handlers = []

+        self._ophook_list = ophook_list
+        register_ophooks_recursively(self._model, self._ophook_list)
+
     @property
     def model(self):
         """Model attached to the engine"""
@@ -87,7 +89,10 @@ class Engine:
         :param loss: Loss value computed by a loss function
         :type loss: :class:`torch.Tensor`
         """
-        return self.optimizer.backward(loss)
+        ret = self.optimizer.backward(loss)
+        for ophook in self._ophook_list:
+            ophook.post_iter()
+        return ret

     def backward_by_grad(self, tensor, grad):
         """Start backward propagation given the gradient of the output tensor
@@ -97,7 +102,10 @@ class Engine:
         :param grad: Gradient passed back to the output
         :type grad: :class:`torch.Tensor`
         """
-        return self.optimizer.backward_by_grad(tensor, grad)
+        ret = self.optimizer.backward_by_grad(tensor, grad)
+        for ophook in self._ophook_list:
+            ophook.post_iter()
+        return ret

     def calc_loss(self, *args, **kwargs):
         """Compute the loss value
@@ -0,0 +1,115 @@
from ._base_ophook import BaseOpHook
from ._memtracer_ophook import MemTracerOpHook
import torch
from typing import List

__all__ = ["BaseOpHook", "MemTracerOpHook", "register_ophooks_recursively"]


# apply a torch.autograd.Function that calls backward_function to tensors in outputs
def _apply_to_tensors_only(module, functional, backward_function, outputs):
    if type(outputs) is tuple:
        touched_outputs = []
        for output in outputs:
            touched_output = _apply_to_tensors_only(module, functional,
                                                    backward_function, output)
            touched_outputs.append(touched_output)
        return tuple(touched_outputs)
    elif type(outputs) is torch.Tensor:
        return functional.apply(module, backward_function, outputs)
    else:
        return outputs


class PreBackwardFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, module, pre_backward_function, outputs):
        ctx.module = module
        ctx.pre_backward_function = pre_backward_function
        module.applied_pre_backward = False
        outputs = outputs.detach()
        return outputs

    @staticmethod
    def backward(ctx, *args):
        ctx.pre_backward_function(ctx.module)
        return (None, None) + args


class PostBackwardFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, module, pre_backward_function, output):
        ctx.module = module
        output = output.detach()
        ctx.pre_backward_function = pre_backward_function
        return output

    @staticmethod
    def backward(ctx, *args):
        """
        Args:
            *args: activation gradient of the next layer.
        Returns:
            grad of the input activation.
        """
        ctx.pre_backward_function(ctx.module)
        return (None, None) + args


def register_ophooks_recursively(module: torch.nn.Module,
                                 ophook_list: List[BaseOpHook] = None,
                                 name: str = ""):
    r"""Recursively register pre/post hooks for all submodules in the module in FWD and BWD."""
    assert isinstance(module, torch.nn.Module)
    has_children = False
    for child_name, child in module.named_children():
        register_ophooks_recursively(child, ophook_list, name + child_name)
        has_children = True

    # Early return on modules with no parameters or buffers that
    # are not in their children.
    if (len(list(module.named_parameters(recurse=False))) == 0
            and len(list(module.named_buffers(recurse=False))) == 0):
        return

    # Return if the module has children; hooks are registered on leaf modules only.
    if has_children:
        return

    if ophook_list is not None:
        for hook in ophook_list:
            assert (isinstance(hook, BaseOpHook))

    def _pre_forward_module_hook(submodule, *args):
        for hook in ophook_list:
            assert isinstance(submodule, torch.nn.Module)
            hook.pre_fwd_exec(submodule, *args)

    def _post_forward_module_hook(submodule, *args):
        for hook in ophook_list:
            assert isinstance(submodule, torch.nn.Module)
            hook.post_fwd_exec(submodule, *args)

    def _pre_backward_module_hook(submodule, inputs, output):
        def _run_before_backward_function(submodule):
            for hook in ophook_list:
                assert isinstance(submodule, torch.nn.Module)
                hook.pre_bwd_exec(submodule, inputs, output)

        return _apply_to_tensors_only(submodule, PreBackwardFunction,
                                      _run_before_backward_function, output)

    def _post_backward_module_hook(submodule, inputs):
        def _run_after_backward_function(submodule):
            for hook in ophook_list:
                assert isinstance(submodule, torch.nn.Module)
                hook.post_bwd_exec(submodule, inputs)

        return _apply_to_tensors_only(submodule, PostBackwardFunction,
                                      _run_after_backward_function, inputs)

    module.register_forward_pre_hook(_pre_forward_module_hook)
    module.register_forward_hook(_post_forward_module_hook)

    # The backward-side hooks piggyback on the forward pass: a forward hook wraps
    # the outputs so a function fires when backward starts at this module, and a
    # forward pre-hook wraps the inputs so a function fires after backward leaves it.
    module.register_forward_hook(_pre_backward_module_hook)
    module.register_forward_pre_hook(_post_backward_module_hook)
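The backward-side hooks rely on a small autograd trick: an identity `torch.autograd.Function` whose `backward` runs a callback. The mechanism can be reproduced standalone; this is a minimal sketch with hypothetical names, not the module's own API:

```python
import torch

# An identity autograd Function that fires a callback when the backward
# pass reaches the tensor it wrapped.
class FireOnBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, callback, x):
        ctx.callback = callback
        # detach() avoids tracking an extra graph edge inside the Function;
        # autograd still connects this node to its inputs.
        return x.detach()

    @staticmethod
    def backward(ctx, *grads):
        ctx.callback()           # analogous to pre_bwd_exec firing
        return (None,) + grads   # identity gradient; None for the callback arg

x = torch.ones(2, requires_grad=True)
y = FireOnBackward.apply(lambda: print("entering backward of wrapped output"), x * 3)
y.sum().backward()               # prints, then x.grad becomes tensor([3., 3.])
```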
@@ -0,0 +1,29 @@
from abc import ABC, abstractmethod
import torch


class BaseOpHook(ABC):
    """This class allows users to add customized operations
    before and after the execution of a PyTorch submodule"""
    def __init__(self):
        pass

    @abstractmethod
    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        pass

    @abstractmethod
    def post_fwd_exec(self, module: torch.nn.Module, *args):
        pass

    @abstractmethod
    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        pass

    @abstractmethod
    def post_bwd_exec(self, module: torch.nn.Module, input):
        pass

    @abstractmethod
    def post_iter(self):
        pass
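A subclass only needs to fill in the five abstract methods. A hedged sketch of a custom hook, wired up with `register_ophooks_recursively` from the module above; the `CountingOpHook` class and the toy model are hypothetical:

```python
import torch
from colossalai.engine.ophooks import BaseOpHook, register_ophooks_recursively

# Hypothetical hook: counts forward executions per leaf submodule class.
class CountingOpHook(BaseOpHook):
    def __init__(self):
        super().__init__()
        self.fwd_counts = {}

    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        name = module.__class__.__name__
        self.fwd_counts[name] = self.fwd_counts.get(name, 0) + 1

    def post_fwd_exec(self, module: torch.nn.Module, *args):
        pass

    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        pass

    def post_bwd_exec(self, module: torch.nn.Module, input):
        pass

    def post_iter(self):
        print(self.fwd_counts)

hook = CountingOpHook()
model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU(), torch.nn.Linear(4, 1))
register_ophooks_recursively(model, [hook])  # ReLU is skipped: no params/buffers
model(torch.randn(2, 4)).sum().backward()
hook.post_iter()  # prints {'Linear': 2}
```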
@@ -0,0 +1,131 @@
import torch
from . import BaseOpHook
from concurrent.futures import ThreadPoolExecutor
from colossalai.registry import OPHOOKS
from colossalai.logging import get_dist_logger
from time import sleep, time
import psutil
import pickle


def get_cuda_memory_used(device):
    """
    Return the memory currently allocated on the CUDA device.
    The peak-memory counter is reset afterwards so that the next
    report starts from fresh statistics.
    """
    ret = torch.cuda.memory_allocated()
    # reset the peak memory counter so the next report is correct
    if hasattr(torch.cuda, "reset_peak_memory_stats"):  # pytorch 1.4+
        torch.cuda.reset_peak_memory_stats()
    return ret


class AsyncMemoryMonitor:
    def __init__(self, power=10):
        """
        An async memory monitor running during computation.
        It samples GPU memory usage of the current device
        at intervals of 1/(10**power) sec.
        """
        self.keep_measuring = False
        self.executor = ThreadPoolExecutor(max_workers=1)
        self.monitor_thread = None
        self.interval = 1 / (10**power)
        self.time_stamps = []
        self.mem_stats = []

    def set_interval(self, power: int):
        self.interval = 1 / (10**power)

    def is_measuring(self):
        return self.keep_measuring

    def start(self):
        self.keep_measuring = True
        self.monitor_thread = self.executor.submit(self._measure_usage)

    def finish(self):
        if self.keep_measuring is False:
            return 0
        self.keep_measuring = False
        max_usage = self.monitor_thread.result()
        self.monitor_thread = None
        self.time_stamps.append(time())
        self.mem_stats.append(max_usage)
        return max_usage

    def _measure_usage(self):
        max_usage = 0
        dev = torch.device(f"cuda:{torch.cuda.current_device()}")
        while self.keep_measuring:
            max_usage = max(
                max_usage,
                get_cuda_memory_used(dev),
            )
            sleep(self.interval)
        return max_usage

    def state_dict(self):
        return {
            "time_stamps": self.time_stamps,
            "mem_stats": self.mem_stats,
        }

    def save(self, filename):
        with open(filename, "wb") as f:
            pickle.dump(self.state_dict(), f)


@OPHOOKS.register_module
class MemTracerOpHook(BaseOpHook):
    def __init__(self, niter=5):
        super().__init__()
        self.async_mem_monitor = AsyncMemoryMonitor()
        self._niter = niter
        self._curiter = 0
        self._logger = get_dist_logger()

    def _isvalid(self, module):
        return module.training and self._curiter < self._niter

    def niter(self):
        return self._niter

    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        if self._isvalid(module):
            self.async_mem_monitor.finish()
            self.async_mem_monitor.start()
            self._logger.debug(f'FWD PRE {module.__class__.__name__}')

    def post_fwd_exec(self, module: torch.nn.Module, *args):
        if self._isvalid(module):
            self.async_mem_monitor.finish()
            self._logger.debug(f'FWD POST {module.__class__.__name__}')

    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        assert isinstance(module, torch.nn.Module)
        if self._isvalid(module):
            self.async_mem_monitor.finish()
            self.async_mem_monitor.start()
            self._logger.debug(f'BWD PRE {module.__class__.__name__}')

    def post_bwd_exec(self, module: torch.nn.Module, input):
        assert isinstance(module, torch.nn.Module)
        if self._isvalid(module):
            self.async_mem_monitor.finish()
            self._logger.debug(f'BWD POST {module.__class__.__name__}')

    def pre_iter(self):
        pass

    def post_iter(self):
        self.async_mem_monitor.finish()
        if self._curiter == self._niter:
            self._logger.info(
                'dumping memory statistics as pickle to ./memstats.pkl')
            self.save_results("memstats.pkl")
        self._curiter += 1

    def save_results(self, filename):
        self.async_mem_monitor.save(filename)
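The monitor can also be driven by hand, outside any hook. A minimal sketch assuming a CUDA device is available; `AsyncMemoryMonitor` is imported from its private module here since only the hook classes are re-exported:

```python
import torch
from colossalai.engine.ophooks._memtracer_ophook import AsyncMemoryMonitor

monitor = AsyncMemoryMonitor(power=3)   # sample every 1e-3 s
model = torch.nn.Linear(1024, 1024).cuda()
x = torch.randn(64, 1024, device="cuda")

monitor.start()
y = model(x)
torch.cuda.synchronize()
peak_bytes = monitor.finish()           # highest sampled allocation in bytes
print(f"peak allocated: {peak_bytes / 2**20:.1f} MiB")
monitor.save("memstats.pkl")            # pickled {'time_stamps': ..., 'mem_stats': ...}
```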
@@ -7,16 +7,17 @@ from torchvision import transforms

 from .registry import Registry

-LAYERS = Registry('layers', third_party_library=[nn])
-LOSSES = Registry('losses')
-MODELS = Registry('models', third_party_library=[tv_models])
-OPTIMIZERS = Registry('optimizers', third_party_library=[optim, dist_optim])
-DATASETS = Registry('datasets', third_party_library=[tv_datasets])
-DIST_GROUP_INITIALIZER = Registry('dist_group_initializer')
-GRADIENT_HANDLER = Registry('gradient_handler')
-LOSSES = Registry('losses', third_party_library=[nn])
-HOOKS = Registry('hooks')
-TRANSFORMS = Registry('transforms', third_party_library=[transforms])
-DATA_SAMPLERS = Registry('data_samplers')
-LR_SCHEDULERS = Registry('lr_schedulers')
-SCHEDULE = Registry('schedules')
+LAYERS = Registry("layers", third_party_library=[nn])
+LOSSES = Registry("losses")
+MODELS = Registry("models", third_party_library=[tv_models])
+OPTIMIZERS = Registry("optimizers", third_party_library=[optim, dist_optim])
+DATASETS = Registry("datasets", third_party_library=[tv_datasets])
+DIST_GROUP_INITIALIZER = Registry("dist_group_initializer")
+GRADIENT_HANDLER = Registry("gradient_handler")
+LOSSES = Registry("losses", third_party_library=[nn])
+HOOKS = Registry("hooks")
+TRANSFORMS = Registry("transforms", third_party_library=[transforms])
+DATA_SAMPLERS = Registry("data_samplers")
+LR_SCHEDULERS = Registry("lr_schedulers")
+SCHEDULE = Registry("schedules")
+OPHOOKS = Registry("ophooks")
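Registration through the new `OPHOOKS` registry is what makes string-based configs resolvable by `build_ophooks`, exactly as `MemTracerOpHook` does above. A hedged sketch of registering a user-defined hook; the `NoopOpHook` class is hypothetical:

```python
import torch

from colossalai.engine.ophooks import BaseOpHook
from colossalai.registry import OPHOOKS


# Hypothetical user-defined hook, registered by name so that configs can
# refer to it as {'type': 'NoopOpHook'} through build_ophooks.
@OPHOOKS.register_module
class NoopOpHook(BaseOpHook):
    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        pass

    def post_fwd_exec(self, module: torch.nn.Module, *args):
        pass

    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        pass

    def post_bwd_exec(self, module: torch.nn.Module, input):
        pass

    def post_iter(self):
        pass
```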
@@ -1,8 +1,4 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
 from typing import Union, List
-from colossalai import engine
 from colossalai.context.parallel_mode import ParallelMode
-
 import torch
@@ -11,17 +7,18 @@ from torch.utils.data import DataLoader
 from tqdm import tqdm

 from colossalai.core import global_context as gpc

 from colossalai.engine import Engine
 from colossalai.engine.schedule import NonPipelineSchedule, BaseSchedule
 from colossalai.logging import DistributedLogger
 from colossalai.utils import MultiTimer
 from colossalai.utils import is_dp_rank_0, is_tp_rank_0, is_no_pp_or_last_stage
-from .hooks import BaseHook
+from colossalai.trainer.hooks import BaseHook


 class Trainer:
     """This is a class intended for easy deployment of users' training and evaluation instead of
     writing their own scripts. It is similar to ``ignite.engine`` and ``keras.engine``, but is
     called `Trainer`.

     :param engine: Engine responsible for the process function
@@ -33,12 +30,13 @@ class Trainer:
     :param logger: Logger used to record the whole training
     :type logger: :class:`colossalai.logging.DistributedLogger`, optional
     """

-    def __init__(self,
-                 engine: Engine,
-                 schedule: BaseSchedule = None,
-                 timer: MultiTimer = None,
-                 logger: DistributedLogger = None):
+    def __init__(
+        self,
+        engine: Engine,
+        schedule: BaseSchedule = None,
+        timer: MultiTimer = None,
+        logger: DistributedLogger = None,
+    ):
         # training-related params
         self._engine = engine
         self._max_epochs = 0
@@ -63,29 +61,28 @@ class Trainer:
         # set schedule which specifies the training iteration for the engine
         if schedule is None:
             schedule = NonPipelineSchedule()
-        if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
-            assert not isinstance(schedule, NonPipelineSchedule), \
-                'NonPipelineSchedule cannot be used for pipeline parallel training, please use PipelineSchedule instead.'
+        if (gpc.is_initialized(ParallelMode.PIPELINE)
+                and gpc.get_world_size(ParallelMode.PIPELINE) > 1):
+            assert not isinstance(
+                schedule, NonPipelineSchedule
+            ), "NonPipelineSchedule cannot be used for pipeline parallel training, please use PipelineSchedule instead."
         self._schedule = schedule
         self._schedule.pre_processing(engine)

     @property
     def cur_epoch(self):
-        """Returns the index of the current epoch.
-        """
+        """Returns the index of the current epoch."""
         return self._cur_epoch

     @cur_epoch.setter
     def cur_epoch(self, epoch: int):
-        """Set how many epochs have been processed.
-        """
+        """Set how many epochs have been processed."""
         # allow setter for training resumption
         self._cur_epoch = epoch

     @property
     def cur_step(self):
-        """Returns how many iteration steps have been processed.
-        """
+        """Returns how many iteration steps have been processed."""
         return self._cur_step

     @property
@@ -131,8 +128,7 @@ class Trainer:
         getattr(self._timer, action)(item, *args, **kwargs)

     def _reset_states(self) -> None:
-        """Clear trainer states
-        """
+        """Clear trainer states"""
         self.states = dict()

     def _call_hooks(self, func, output=None):
@@ -152,99 +148,122 @@ class Trainer:

     @staticmethod
     def _should_display_progress(display_progress: bool):
-        """ Only display progress on DP rank 0, TP rank 0 and PP last rank
-        """
-        return display_progress and is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage()
+        """Only display progress on DP rank 0, TP rank 0 and PP last rank"""
+        return (display_progress and is_dp_rank_0() and is_tp_rank_0()
+                and is_no_pp_or_last_stage())

-    def _train_epoch(self,
-                     train_dataloader: DataLoader,
-                     epoch: int = None,
-                     display_progress: bool = False,
-                     return_output_label: bool = True):
+    def _train_epoch(
+        self,
+        train_dataloader: DataLoader,
+        epoch: int = None,
+        display_progress: bool = False,
+        return_output_label: bool = True,
+    ):
         # set training state
         self._engine.train()
         data_iter = iter(train_dataloader)
         progress = range(self._steps_per_epoch)
         if display_progress:
             if epoch is None:
-                progress = tqdm(progress, desc='[Train]')
+                progress = tqdm(progress, desc="[Train]")
             else:
-                progress = tqdm(progress, desc=f'[Epoch {epoch} / Train]')
+                progress = tqdm(progress, desc=f"[Epoch {epoch} / Train]")

-        self._call_hooks('before_train_epoch')
-        self._call_timer(action='start', item='Train-epoch')
+        self._call_hooks("before_train_epoch")
+        self._call_timer(action="start", item="Train-epoch")
         for i in progress:
-            self._call_hooks('before_train_iter')
-            self._call_timer(action='start', item='Train-step')
+            self._call_hooks("before_train_iter")
+            self._call_timer(action="start", item="Train-step")

             # run 1 training step
             self.engine.zero_grad()
             logits, label, loss = self.schedule.forward_backward_step(
-                self.engine, data_iter, forward_only=False, return_loss=True, return_output_label=return_output_label)
+                self.engine,
+                data_iter,
+                forward_only=False,
+                return_loss=True,
+                return_output_label=return_output_label,
+            )
             self.engine.step()
-            self._call_timer(action='stop', item='Train-step', keep_in_history=True)
-            self._call_hooks('after_train_iter', output=(logits, label, loss))
+            self._call_timer(action="stop",
+                             item="Train-step",
+                             keep_in_history=True)
+            self._call_hooks("after_train_iter", output=(logits, label, loss))

             self._cur_step += 1

             if display_progress:
-                if 'step_metrics' in self.states:
-                    progress.set_postfix(**self.states['step_metrics'])
+                if "step_metrics" in self.states:
+                    progress.set_postfix(**self.states["step_metrics"])

             # stop when max iter is reached
             if self._exceed_max_step():
                 break

-        self._call_timer(action='stop', item='Train-epoch', keep_in_history=True)
-        self._call_hooks('after_train_epoch')
-        self._call_timer(action='reset', item='Train-epoch')
+        self._call_timer(action="stop",
+                         item="Train-epoch",
+                         keep_in_history=True)
+        self._call_hooks("after_train_epoch")
+        self._call_timer(action="reset", item="Train-epoch")

-    def _eval(self,
-              test_dataloader: DataLoader,
-              epoch: int = None,
-              display_progress: bool = False,
-              return_output_label: bool = True):
+    def _eval(
+        self,
+        test_dataloader: DataLoader,
+        epoch: int = None,
+        display_progress: bool = False,
+        return_output_label: bool = True,
+    ):
         # switch engine status
         self._engine.eval()

         data_iter = iter(test_dataloader)
         num_steps = len(test_dataloader)

-        self._call_hooks('before_test')
+        self._call_hooks("before_test")
         # prepare progress bar
         progress = range(num_steps)
         if display_progress:
-            desc = 'Evaluation'
+            desc = "Evaluation"
             if epoch is not None:
-                desc = '[Epoch %d / Test]' % epoch
+                desc = "[Epoch %d / Test]" % epoch
             progress = tqdm(progress, desc=desc)

-        self._call_hooks('before_test_epoch')
-        self._call_timer(action='start', item='Test-epoch')
+        self._call_hooks("before_test_epoch")
+        self._call_timer(action="start", item="Test-epoch")
         with torch.no_grad():
             for _ in progress:
-                self._call_hooks('before_test_iter')
-                self._call_timer(action='start', item='Test-step')
+                self._call_hooks("before_test_iter")
+                self._call_timer(action="start", item="Test-step")
                 logits, label, loss = self.schedule.forward_backward_step(
-                    self.engine, data_iter, forward_only=True, return_loss=True, return_output_label=return_output_label)
-                self._call_timer(action='stop', item='Test-step', keep_in_history=True)
-                self._call_hooks('after_test_iter',
+                    self.engine,
+                    data_iter,
+                    forward_only=True,
+                    return_loss=True,
+                    return_output_label=return_output_label,
+                )
+                self._call_timer(action="stop",
+                                 item="Test-step",
+                                 keep_in_history=True)
+                self._call_hooks("after_test_iter",
                                  output=(logits, label, loss))

                 if display_progress:
-                    if 'step_metrics' in self.states:
-                        progress.set_postfix(**self.states['step_metrics'])
+                    if "step_metrics" in self.states:
+                        progress.set_postfix(**self.states["step_metrics"])

-        self._call_timer(action='stop', item='Test-epoch', keep_in_history=True)
-        self._call_hooks('after_test_epoch')
-        self._call_hooks('after_test')
-        self._call_timer(action='reset', item='Test-step')
-        self._call_timer(action='reset', item='Test-epoch')
+        self._call_timer(action="stop",
+                         item="Test-epoch",
+                         keep_in_history=True)
+        self._call_hooks("after_test_epoch")
+        self._call_hooks("after_test")
+        self._call_timer(action="reset", item="Test-step")
+        self._call_timer(action="reset", item="Test-epoch")

     def _exceed_max_step(self):
         return self._max_steps is not None and self._cur_step >= self._max_steps

-    def fit(self,
+    def fit(
+        self,
         train_dataloader: DataLoader,
         epochs: int,
         max_steps: int = None,
@@ -253,7 +272,7 @@ class Trainer:
         hooks: List[BaseHook] = None,
         display_progress: bool = False,
         return_output_label: bool = True,
-        ):
+    ):
         """Trains the model to fit training data.

         :param train_dataloader: DataLoader in training
@@ -290,7 +309,9 @@ class Trainer:
         # reset hooks
         self._reset_states()
         if hooks is not None:
-            assert isinstance(hooks, list), f'expected argument hooks be to list, but got {type(hooks)}'
+            assert isinstance(
+                hooks, list
+            ), f"expected argument hooks to be list, but got {type(hooks)}"
         else:
             hooks = []
         self.hooks = hooks
@@ -298,13 +319,16 @@ class Trainer:
         if self._verbose:
             for hook in self.hooks:
                 self._logger.info(
-                    f'Using {hook.__class__.__name__} for training, priority = {hook.priority}', ranks=[0])
-            self._logger.info("Lower value means higher priority for calling hook function", ranks=[0])
-        self._call_hooks('after_hook_is_attached')
+                    f"Using {hook.__class__.__name__} for training, priority = {hook.priority}",
+                    ranks=[0],
+                )
+            self._logger.info(
+                "Lower value means higher priority for calling hook function",
+                ranks=[0])
+        self._call_hooks("after_hook_is_attached")

         # start train
         self._engine.train()
-        self._call_hooks('before_train')
+        self._call_hooks("before_train")

         # recover step value if resuming training
         last_epoch = self._cur_epoch
@@ -317,16 +341,17 @@ class Trainer:
                 train_dataloader=train_dataloader,
                 epoch=epoch,
                 display_progress=display_progress,
-                return_output_label=return_output_label
+                return_output_label=return_output_label,
             )

             # start eval
             if should_test and epoch % test_interval == 0:
-                self._eval(test_dataloader=test_dataloader,
-                           display_progress=display_progress,
-                           epoch=epoch,
-                           return_output_label=return_output_label
-                           )
+                self._eval(
+                    test_dataloader=test_dataloader,
+                    display_progress=display_progress,
+                    epoch=epoch,
+                    return_output_label=return_output_label,
+                )

             self._cur_epoch += 1

@@ -334,16 +359,19 @@ class Trainer:
             if self._exceed_max_step():
                 self._logger.info(
                     f"Max number of steps {max_steps} has been reached, training is stopped automatically",
-                    ranks=[0])
+                    ranks=[0],
+                )
                 break
-        self._call_hooks('after_train')
-        self._call_timer('reset', 'Train-epoch')
+        self._call_hooks("after_train")
+        self._call_timer("reset", "Train-epoch")

-    def evaluate(self,
-                 test_dataloader: DataLoader,
-                 hooks: List[BaseHook] = None,
-                 display_progress: bool = False,
-                 return_output_label: bool = True):
+    def evaluate(
+        self,
+        test_dataloader: DataLoader,
+        hooks: List[BaseHook] = None,
+        display_progress: bool = False,
+        return_output_label: bool = True,
+    ):
         """Evaluates the model with testing data.

         :param test_dataloader: DataLoader in testing
@@ -362,7 +390,9 @@ class Trainer:
         # reset hooks
         self._reset_states()
         if hooks is not None:
-            assert isinstance(hooks, list), f'expected argument hooks be to list, but got {type(hooks)}'
+            assert isinstance(
+                hooks, list
+            ), f"expected argument hooks to be list, but got {type(hooks)}"
         else:
             hooks = []
         self.hooks = hooks
@@ -370,15 +400,20 @@ class Trainer:
         if self._verbose:
             for hook in self.hooks:
                 self._logger.info(
-                    f'Using {hook.__class__.__name__} for training, priority = {hook.priority}', ranks=[0])
-            self._logger.info("Lower value means higher priority for calling hook function", ranks=[0])
-        self._call_hooks('after_hook_is_attached')
+                    f"Using {hook.__class__.__name__} for training, priority = {hook.priority}",
+                    ranks=[0],
+                )
+            self._logger.info(
+                "Lower value means higher priority for calling hook function",
+                ranks=[0])
+        self._call_hooks("after_hook_is_attached")

         # eval
-        self._eval(test_dataloader=test_dataloader,
-                   display_progress=display_progress,
-                   return_output_label=return_output_label
-                   )
+        self._eval(
+            test_dataloader=test_dataloader,
+            display_progress=display_progress,
+            return_output_label=return_output_label,
+        )

     def predict(self, data: Union[Tensor, List[Tensor]]):
         """Uses trained model to make a prediction for a tensor or a tensor list.
@@ -399,6 +434,8 @@ class Trainer:
         # for compatibility with schedule
         simple_dataloader = [(data, None)]
         data_iter = iter(simple_dataloader)
-        output, _, _ = self.schedule.forward_backward_step(
-            self.engine, data_iter, forward_only=True, return_loss=False)
+        output, _, _ = self.schedule.forward_backward_step(self.engine,
+                                                           data_iter,
+                                                           forward_only=True,
+                                                           return_loss=False)
         return output
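Putting the pieces together, ophooks are attached when the Engine is constructed. A hedged end-to-end sketch: the toy model and raw SGD optimizer are placeholders (ColossalAI normally wraps the optimizer so that `engine.backward` can delegate to it), while `ophook_list` is the new `Engine` argument introduced by this commit:

```python
import torch
from colossalai.engine import Engine
from colossalai.engine.ophooks import MemTracerOpHook

model = torch.nn.Linear(8, 2)
engine = Engine(model=model,
                optimizer=torch.optim.SGD(model.parameters(), lr=0.1),
                criterion=torch.nn.CrossEntropyLoss(),
                ophook_list=[MemTracerOpHook(niter=2)])
# engine.backward(loss) now calls post_iter() on every ophook; the tracer
# dumps ./memstats.pkl after its first `niter` training iterations.
```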
@@ -6,6 +6,7 @@ from pathlib import Path
 import pytest

 from colossalai.context.config import Config
+from colossalai.builder import build_ophooks


 @pytest.mark.cpu
@@ -17,3 +18,10 @@ def test_load_config():
     assert config.train_data.dataset, 'cannot access grandchild attribute'
     assert isinstance(config.train_data.dataset.transform_pipeline[0], dict), \
         f'expected attribute transform_pipeline elements to be a dict, but found {type(config.train_data.dataset.transform_pipeline)}'
+
+
+@pytest.mark.cpu
+def test_load_ophooks():
+    dict = {'type': 'MemTracerOpHook', 'niter': 2}
+    ophook = build_ophooks(dict)
+    assert ophook.niter() == 2