From 2e0b0b76990e8d4e337add483d878c0f61cf5097 Mon Sep 17 00:00:00 2001
From: Frank Lee
Date: Thu, 11 Nov 2021 17:41:45 +0800
Subject: [PATCH] fixed trainer

---
 colossalai/initialize.py       |  15 +-
 colossalai/trainer/_trainer.py | 309 ++++++++++++++++++++------------
 2 files changed, 189 insertions(+), 135 deletions(-)

diff --git a/colossalai/initialize.py b/colossalai/initialize.py
index 9405fc8fe..768a3a6ed 100644
--- a/colossalai/initialize.py
+++ b/colossalai/initialize.py
@@ -337,19 +337,6 @@ def initialize(config: Union[str, dict] = None,
         optimizer = build_optimizer_wrapper(fp16_cfg, optimizer)
         logger.info('Optimizer is created', ranks=[0])
 
-    lr_scheduler = None
-    if hasattr(gpc.config, 'lr_scheduler'):
-        # if hasattr(gpc.config, 'num_steps'):
-        #     total_steps = gpc.config.num_steps
-        # elif hasattr(gpc.config, 'num_epochs'):
-        #     total_steps = int(gpc.config.num_epochs * len(train_dataloader))
-        # else:
-        #     raise Exception(
-        #         'Please specify training stopping criterion num_steps or num_epochs in your configuration.'
-        #     )
-        lr_scheduler = build_lr_scheduler(gpc.config.lr_scheduler, optimizer)
-        logger.info('Learning rate scheduler is created', ranks=[0])
-
     # pipeline or no pipeline schedule
     if hasattr(gpc.config, 'fp16'):
         amp_type = gpc.config.fp16.mode
@@ -367,4 +354,4 @@ def initialize(config: Union[str, dict] = None,
     else:
         schedule = NoPipelineSchedule(amp_type=amp_type, amp_config=amp_cfg)
 
-    return model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler
+    return model, train_dataloader, test_dataloader, criterion, optimizer, schedule

diff --git a/colossalai/trainer/_trainer.py b/colossalai/trainer/_trainer.py
index a8706ceb1..a202c3c8a 100644
--- a/colossalai/trainer/_trainer.py
+++ b/colossalai/trainer/_trainer.py
@@ -5,7 +5,10 @@ from typing import Optional
 from typing import Union, List
 
 import torch
+import torch.nn as nn
+from torch.nn.modules.loss import _Loss
 from torch import Tensor
+from torch.optim import Optimizer
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 
@@ -16,6 +19,8 @@ from colossalai.engine import Engine
 from colossalai.logging import get_global_dist_logger
 from colossalai.utils import get_global_multitimer, is_dp_rank_0, is_tp_rank_0, is_no_pp_or_last_stage
 from colossalai.nn.data import DataParallelSampler
+from colossalai.utils import MultiTimer
+from colossalai.engine import BaseSchedule
 
 
 class Trainer:
@@ -31,42 +36,38 @@ class Trainer:
     :type verbose: bool, optional
     """
     def __init__(self,
-                 engine: Engine,
-                 hooks_cfg: Optional[Config] = None,
-                 verbose: bool = False):
+                 model: nn.Module,
+                 optimizer: Optimizer,
+                 loss: _Loss,
+                 schedule: BaseSchedule,
+                 verbose: bool = False,
+                 timer: MultiTimer = None):
         # training-ralated params
-        self._engine = engine
-        self._max_epochs = float('inf')
-        self._max_steps = float('inf')
+        self._model = model
+        self._optimizer = optimizer
+        self._criterion = loss
+        self._engine = Engine(
+            model=model,
+            optimizer=optimizer,
+            schedule=schedule)
+        self._max_epochs = 0
         self._cur_epoch = 0
+        self._max_steps = 0
         self._cur_step = 0
-
-        # data-related params
-        self._train_dataloader = None
-        self._test_dataloader = None
+        self._steps_per_epoch = 0
 
         # misc params
-        self._display_progress = False
         self._logger = get_global_dist_logger()
         self._verbose = verbose
 
         # hooks can store states in this dict, and could be consumed by other hooks
-        self.states = {}
+        self.states = dict()
 
         # build hooks
        self.hooks = list()
-        if hooks_cfg is not None:
-            for cfg in hooks_cfg:
-                hook = build_hooks(cfg, self)
-                self.hooks.append(hook)
-        self.hooks.sort(key=lambda hook: hook.priority)
-        if self._verbose:
-            for hook in self.hooks:
-                self._logger.info(
-                    f'build {hook.__class__.__name__} for train, priority = {hook.priority}', ranks=[0])
 
-        # timer
-        self._timer = get_global_multitimer()
+        # multi-timer for time benchmarking
+        self._timer = timer
 
     @property
     def cur_epoch(self):
@@ -80,7 +81,50 @@ class Trainer:
         """
         return self._cur_step
 
-    def call_hooks(self, func, output=None):
+    @property
+    def max_epochs(self):
+        return self._max_epochs
+
+    @property
+    def max_steps(self):
+        return self._max_steps
+
+    @property
+    def steps_per_epoch(self):
+        return self._steps_per_epoch
+
+    @property
+    def optimizer(self):
+        return self._optimizer
+
+    @property
+    def model(self):
+        return self._model
+
+    def set_epoch(self, epoch):
+        """Sets current epoch number.
+
+        :param epoch: Epoch number to be set
+        :type epoch: int
+        """
+        self._cur_epoch = epoch
+
+    def _set_current_step(self, epoch: int):
+        """Sets current step number based on the epoch number.
+
+        :param epoch: Epoch number from which the current step is computed
+        :type epoch: int
+        """
+        self._cur_step = epoch * self._steps_per_epoch
+
+    def _call_timer(self, action: str, item: str, *args, **kwargs) -> None:
+        if self._timer is not None:
+            getattr(self._timer, action)(item, *args, **kwargs)
+
+    def _reset_states(self) -> None:
+        self.states = dict()
+
+    def _call_hooks(self, func, output=None):
         """Calls specific hooks in the current time point.
 
         :param func: A string represents the time point
@@ -95,161 +139,185 @@ class Trainer:
         else:
             getattr(hook, func)(*output)
 
-    def exceed_max_step(self):
-        """Checks whether the trainer exceeds the maximum number of runnning iterations.
-        """
-        return self._cur_step >= self._max_steps
+    @staticmethod
+    def _should_display_progress(display_progress: bool):
+        return display_progress and is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage()
 
-    def set_epoch(self, epoch):
-        """Sets current epoch number.
-
-        :param epoch: Epoch number to be set
-        :type epoch: int
-        """
-        self._cur_epoch = epoch
-
-    def _recover_steps(self):
-        step = self.cur_step * self._engine.schedule.num_steps
-        self._cur_step = step
-
-    def _set_display_progress(self, display_progress: bool):
-        self._display_progress = display_progress and is_dp_rank_0(
-        ) and is_tp_rank_0() and is_no_pp_or_last_stage()
-
-    def _train_epoch(self, epoch: int = None):
+    def _train_epoch(self,
+                     train_dataloader: DataLoader,
+                     epoch: int = None,
+                     max_steps: int = None,
+                     display_progress: bool = False):
         # set sampler epoch
         if epoch is not None and \
-                hasattr(self._engine.train_dataloader, 'sampler') and \
-                isinstance(self._engine.train_dataloader.sampler, DataParallelSampler):
-            self._engine.train_dataloader.sampler.set_epoch(epoch)
+                hasattr(train_dataloader, 'sampler') and \
+                isinstance(train_dataloader.sampler, DataParallelSampler):
+            train_dataloader.sampler.set_epoch(epoch)
 
+        # set training state
         self._engine.train()
+        data_iter = iter(train_dataloader)
 
-        progress = range(self._engine.schedule.num_steps)
-        if self._display_progress:
+        progress = range(len(train_dataloader))
+        if display_progress:
             if epoch is None:
                 progress = tqdm(progress, desc='[Train]')
             else:
                 progress = tqdm(progress, desc=f'[Epoch {epoch} train]')
 
         # train 1 epoch
-        self.call_hooks('before_train_epoch')
-        self._timer.start('train-epoch')
-        for _ in progress:
-            self.call_hooks('before_train_iter')
-            self._timer.start('train-step')
-            logits, label, loss = self._engine.step()
-            self._timer.stop('train-step', keep_in_history=True)
-            self.call_hooks('after_train_iter', output=(logits, label, loss))
+        self._call_hooks('before_train_epoch')
+        self._call_timer(action='start', item='train-epoch')
+        for i in progress:
+            self._call_hooks('before_train_iter')
+            self._call_timer(action='start', item='train-step')
+
+            # run 1 training step
+            is_last_iteration = i == len(train_dataloader) - 1
+            logits, label, loss = self._engine.step(data_iter, self._model, self._criterion, self._optimizer, is_last_iteration)
+            self._call_timer(action='stop', item='train-step', keep_in_history=True)
+            self._call_hooks('after_train_iter', output=(logits, label, loss))
 
             self._cur_step += 1
 
-            if self.exceed_max_step():
-                # stop when max iter is reached
+            # stop when max iter is reached
+            if self._exceed_max_step():
                 break
-        self._engine.complete()
-        self._timer.stop('train-epoch', keep_in_history=True)
-        self.call_hooks('after_train_epoch')
-        self._timer.reset('train-step')
+
+        self._call_timer(action='stop', item='train-epoch', keep_in_history=True)
+        self._call_hooks('after_train_epoch')
+        self._call_timer(action='reset', item='train-step')
 
     def _eval(self,
+              test_dataloader: DataLoader,
+              display_progress: bool = False,
               epoch: int = None,
               return_loss: bool = True):
         # switch engine status
         self._engine.eval()
+        data_iter = iter(test_dataloader)
 
-        self.call_hooks('before_test')
+        self._call_hooks('before_test')
         with torch.no_grad():
             # prepare progress bar
-            progress = range(self._engine.schedule.num_steps)
-            if self._display_progress:
+            progress = range(len(test_dataloader))
+            if display_progress:
                 desc = 'Evaluation'
                 if epoch is not None:
                     desc = '[Epoch %d val]' % epoch
                 progress = tqdm(progress, desc=desc)
 
-            self.call_hooks('before_test_epoch')
-            self._timer.start('test-epoch')
+            self._call_hooks('before_test_epoch')
+            self._call_timer(action='start', item='test-epoch')
             for _ in progress:
-                self.call_hooks('before_test_iter')
-                self._timer.start('test-step')
-                logits, label, loss = self._engine.step(
-                    return_loss=return_loss)
-                self._timer.stop('test-step', keep_in_history=True)
-                self.call_hooks('after_test_iter',
-                                output=(logits, label, loss))
-            self._timer.stop('test-epoch', keep_in_history=True)
-            self.call_hooks('after_test_epoch')
-        self.call_hooks('after_test')
-        self._timer.reset('test-step')
-        self._timer.reset('test-epoch')
+                self._call_hooks('before_test_iter')
+                self._call_timer(action='start', item='test-step')
+                logits, label, loss = self._engine.step(data_iter, self._model, self._criterion, return_loss=return_loss)
+                self._call_timer(action='stop', item='test-step', keep_in_history=True)
+                self._call_hooks('after_test_iter',
+                                 output=(logits, label, loss))
+            self._call_timer(action='stop', item='test-epoch', keep_in_history=True)
+            self._call_hooks('after_test_epoch')
+        self._call_hooks('after_test')
+        self._call_timer(action='reset', item='test-step')
+        self._call_timer(action='reset', item='test-epoch')
+
+    def _exceed_max_step(self):
+        return self._max_steps is not None and self._cur_step >= self._max_steps
 
     def fit(self,
             train_dataloader: DataLoader,
-            test_dataloader: DataLoader = None,
-            max_epochs: int = None,
+            epochs: int,
             max_steps: int = None,
+            test_dataloader: DataLoader = None,
             test_interval: int = 1,
-            display_progress: bool = False):
+            hooks_cfg: dict = None,
+            display_progress: bool = False,
+            gradient_accumulation: int = 1):
         """Trains the model to fit training data.
 
         :param train_dataloader: DataLoader in training
-        :param test_dataloader: DataLoader in testing
-        :param max_epochs: Maximum number of epoches
+        :param epochs: Number of epochs to train for
         :param max_steps: Maximum number of running iterations
+        :param test_dataloader: DataLoader in testing
         :param test_interval: Interval of testing
+        :param hooks_cfg: A list of hook configurations
         :param display_progress: If True, the training progress will be printed
+        :param gradient_accumulation: Number of steps over which gradients are accumulated
         :type train_dataloader: DataLoader
-        :type test_dataloader: DataLoader
-        :type max_epochs: int
+        :type epochs: int
         :type max_steps: int
+        :type test_dataloader: DataLoader
         :type test_interval: int
+        :type hooks_cfg: dict
         :type display_progress: bool
+        :type gradient_accumulation: int
         """
-        # prepare dataloaders
-        self._train_dataloader = train_dataloader
-        self._engine.set_dataloader(self._train_dataloader, train=True)
-        self._engine.train()
-
-        should_test = False
-        if test_dataloader is not None:
-            self._test_dataloader = test_dataloader
-            self._engine.set_dataloader(self._test_dataloader, train=False)
-            should_test = True
-
-        # decide the
-        if max_epochs is not None:
-            self._max_epochs = max_epochs
-        if max_steps is not None:
-            self._max_steps = max_steps
-        self._set_display_progress(display_progress)
-
-        # start train
-        self.call_hooks('before_train')
+        # set epochs and steps
+        self._steps_per_epoch = len(train_dataloader) // gradient_accumulation
+        self._max_steps = max_steps
+        self._max_epochs = epochs
 
         # recover step value if resuming training
-        if self.cur_epoch != 0:
-            self._recover_steps()
-
         last_epoch = self._cur_epoch
+        if self.cur_epoch != 0:
+            self._set_current_step(last_epoch)
 
-        for epoch in range(last_epoch, self._max_epochs):
+        # check if testing is required
+        should_test = False
+        if test_dataloader is not None:
+            should_test = True
+
+        display_progress = self._should_display_progress(display_progress)
+
+        # reset hooks
+        self._reset_states()
+        self.hooks = list()
+
+        # build hooks
+        if hooks_cfg is not None:
+            for cfg in hooks_cfg:
+                hook = build_hooks(cfg, self)
+                self.hooks.append(hook)
+        self.hooks.sort(key=lambda hook: hook.priority)
+        if self._verbose:
+            for hook in self.hooks:
+                self._logger.info(
+                    f'build {hook.__class__.__name__} for train, priority = {hook.priority}', ranks=[0])
+
+        # start train
+        self._engine.train()
+        self._call_hooks('before_train')
+
+        for epoch in range(last_epoch, epochs):
             # train for one epoch
-            self._train_epoch(epoch)
+            self._train_epoch(
+                train_dataloader=train_dataloader,
+                epoch=epoch,
+                max_steps=max_steps,
+                display_progress=display_progress)
 
             # start eval
             if should_test and epoch % test_interval == 0:
-                self._eval(epoch, return_loss=True)
+                self._eval(test_dataloader=test_dataloader,
+                           display_progress=display_progress,
+                           epoch=epoch,
+                           return_loss=True)
 
             self._cur_epoch += 1
 
             # check for termination
-            if self.exceed_max_step():
+            if self._exceed_max_step():
                 self._logger.info(
-                    f"Max number of steps {self._max_steps} has been reached, training is stopped automatically")
+                    f"Max number of steps {max_steps} has been reached, training is stopped automatically")
                 break
-        self.call_hooks('after_train')
-        self._timer.reset('train-epoch')
+        self._call_hooks('after_train')
+        self._call_timer(action='reset', item='train-epoch')
 
     def evaluate(self,
@@ -262,15 +330,14 @@ class Trainer:
                  test_dataloader: DataLoader,
                  display_progress: bool = False):
         """Evaluates the model with testing data.
 
         :param test_dataloader: DataLoader in testing
         :param display_progress: If True, the evaluation progress will be printed
         :type test_dataloader: DataLoader
        :type display_progress: bool, optional
         """
-        # set dataloader
-        self._test_dataloader = test_dataloader
-        self._engine.set_dataloader(self._test_dataloader, train=True)
-
-        # set
-        self._set_display_progress(display_progress)
+        # set display
+        display_progress = self._should_display_progress(display_progress)
 
         # eval
-        self._eval(return_loss=True)
+        self._eval(test_dataloader=test_dataloader,
+                   display_progress=display_progress,
+                   return_loss=True)
 
     def predict(self, data: Union[Tensor, List[Tensor]]):
         """Uses trained model to make a prediction for a tensor or a tensor list.
@@ -290,8 +357,8 @@ class Trainer:
         # prepare a list of (data, label) to make it iterable
         # for compatibility with schedule
         simple_dataloader = [(data, None)]
-        self._engine.set_dataloader(simple_dataloader)
-        output, _, _ = self._engine.step(return_loss=False)
+        data_iter = iter(simple_dataloader)
+        output, _, _ = self._engine.step(data_iter, self._model, self._criterion, return_loss=False)
         return output
 
     def save(self, path: str, suffix: str = ''):
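
Usage sketch (not part of the patch): after this change, colossalai.initialize returns six values instead of seven (the lr_scheduler is gone), the Trainer is built directly from the model, optimizer, loss, and schedule rather than from a pre-built Engine, and hooks_cfg moves from the constructor to fit(). A minimal sketch of the new call pattern follows; the config path and the num_epochs and hooks config fields are hypothetical stand-ins.

    import colossalai
    from colossalai.core import global_context as gpc
    from colossalai.trainer import Trainer
    from colossalai.utils import MultiTimer

    # initialize() no longer returns an lr_scheduler
    model, train_dataloader, test_dataloader, criterion, optimizer, schedule = \
        colossalai.initialize(config='./config.py')  # hypothetical config path

    # the Trainer now builds its Engine internally from these parts
    trainer = Trainer(model=model,
                      optimizer=optimizer,
                      loss=criterion,
                      schedule=schedule,
                      verbose=True,
                      timer=MultiTimer())  # optional; pass None to disable timing

    # hook configs are now passed to fit() instead of the constructor
    trainer.fit(train_dataloader=train_dataloader,
                epochs=gpc.config.num_epochs,  # hypothetical config field
                test_dataloader=test_dataloader,
                test_interval=1,
                hooks_cfg=gpc.config.hooks,    # hypothetical config field
                display_progress=True)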