fixed trainer

pull/27/head
Frank Lee 2021-11-11 17:41:45 +08:00
parent af88570f4b
commit 2e0b0b7699
2 changed files with 189 additions and 135 deletions

View File

@@ -337,19 +337,6 @@ def initialize(config: Union[str, dict] = None,
optimizer = build_optimizer_wrapper(fp16_cfg, optimizer)
logger.info('Optimizer is created', ranks=[0])
lr_scheduler = None
if hasattr(gpc.config, 'lr_scheduler'):
# if hasattr(gpc.config, 'num_steps'):
# total_steps = gpc.config.num_steps
# elif hasattr(gpc.config, 'num_epochs'):
# total_steps = int(gpc.config.num_epochs * len(train_dataloader))
# else:
# raise Exception(
# 'Please specify training stopping criterion num_steps or num_epochs in your configuration.'
# )
lr_scheduler = build_lr_scheduler(gpc.config.lr_scheduler, optimizer)
logger.info('Learning rate scheduler is created', ranks=[0])
# pipeline or no pipeline schedule
if hasattr(gpc.config, 'fp16'):
amp_type = gpc.config.fp16.mode
@@ -367,4 +354,4 @@ def initialize(config: Union[str, dict] = None,
else:
schedule = NoPipelineSchedule(amp_type=amp_type, amp_config=amp_cfg)
return model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler
return model, train_dataloader, test_dataloader, criterion, optimizer, schedule
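With this change, initialize no longer builds the learning-rate scheduler and drops lr_scheduler from its return value. A minimal sketch of what calling code might look like after this commit; the config access mirrors the deleted block, and the six-element unpacking follows the new return statement:

# hypothetical caller-side setup after this commit
model, train_dataloader, test_dataloader, criterion, optimizer, schedule = initialize(config)

# building the scheduler is now the caller's responsibility
lr_scheduler = None
if hasattr(gpc.config, 'lr_scheduler'):
    lr_scheduler = build_lr_scheduler(gpc.config.lr_scheduler, optimizer)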

View File

@@ -5,7 +5,10 @@ from typing import Optional
from typing import Union, List
import torch
import torch.nn as nn
from torch.nn.modules.loss import _Loss
from torch import Tensor
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from tqdm import tqdm
@@ -16,6 +19,8 @@ from colossalai.engine import Engine
from colossalai.logging import get_global_dist_logger
from colossalai.utils import get_global_multitimer, is_dp_rank_0, is_tp_rank_0, is_no_pp_or_last_stage
from colossalai.nn.data import DataParallelSampler
from colossalai.utils import MultiTimer
from colossalai.engine import Engine, BaseSchedule
class Trainer:
@@ -31,42 +36,38 @@ class Trainer:
:type verbose: bool, optional
"""
def __init__(self,
engine: Engine,
hooks_cfg: Optional[Config] = None,
verbose: bool = False):
model: nn.Module,
optimizer: Optimizer,
loss: _Loss,
schedule: BaseSchedule,
verbose: bool = False,
timer: MultiTimer = None):
# training-related params
self._engine = engine
self._max_epochs = float('inf')
self._max_steps = float('inf')
self._model = model
self._optimizer = optimizer
self._criterion = loss
self._engine = Engine(
model=model,
optimizer=optimizer,
schedule=schedule)
self._max_epochs = 0
self._cur_epoch = 0
self._max_steps = 0
self._cur_step = 0
# data-related params
self._train_dataloader = None
self._test_dataloader = None
self._steps_per_epoch = 0
# misc params
self._display_progress = False
self._logger = get_global_dist_logger()
self._verbose = verbose
# hooks can store states in this dict, which can be consumed by other hooks
self.states = {}
self.states = dict()
# build hooks
self.hooks = list()
if hooks_cfg is not None:
for cfg in hooks_cfg:
hook = build_hooks(cfg, self)
self.hooks.append(hook)
self.hooks.sort(key=lambda hook: hook.priority)
if self._verbose:
for hook in self.hooks:
self._logger.info(
f'build {hook.__class__.__name__} for train, priority = {hook.priority}', ranks=[0])
# timer
self._timer = get_global_multitimer()
# multi-timer for time benchmarking
self._timer = timer
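Since the Trainer now constructs its Engine internally, callers only supply the bare components. A minimal construction sketch under the new signature; the loss and timer choices are illustrative, not from the diff:

# hypothetical setup; model, optimizer and schedule come from elsewhere (e.g. initialize())
trainer = Trainer(model=model,
                  optimizer=optimizer,
                  loss=nn.CrossEntropyLoss(),  # illustrative criterion
                  schedule=schedule,
                  verbose=True,
                  timer=MultiTimer())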
@property
def cur_epoch(self):
@@ -80,7 +81,50 @@ class Trainer:
"""
return self._cur_step
def call_hooks(self, func, output=None):
@property
def max_epochs(self):
return self._max_epochs
@property
def max_steps(self):
return self._max_steps
@property
def steps_per_epoch(self):
return self._steps_per_epoch
@property
def optimizer(self):
return self._optimizer
@property
def model(self):
return self._model
def set_epoch(self, epoch):
"""Sets current epoch number.
:param epoch: Epoch number to be set
:type epoch: int
"""
self._cur_epoch = epoch
def _set_current_step(self, epoch: int):
"""Sets current step number.
:param epoch: Step number to be set
:type epoch: int
"""
self._cur_step = epoch * self._steps_per_epoch
def _call_timer(self, action: str, item: str, *args, **kwargs) -> None:
if self._timer is not None:
getattr(self._timer, action)(item, *args, **kwargs)
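_call_timer dispatches on the timer object by attribute name, so all timing becomes a no-op when no MultiTimer was passed to the constructor. For example, the epoch-timing call used later in _train_epoch:

# equivalent to self._timer.start('train-epoch') when a timer exists,
# silently skipped otherwise
self._call_timer(action='start', item='train-epoch')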
def _reset_states(self) -> None:
self.states = dict()
def _call_hooks(self, func, output=None):
"""Calls specific hooks in the current time point.
:param func: A string represents the time point
@@ -95,161 +139,185 @@ class Trainer:
else:
getattr(hook, func)(*output)
def exceed_max_step(self):
"""Checks whether the trainer exceeds the maximum number of runnning iterations.
"""
return self._cur_step >= self._max_steps
@staticmethod
def _should_display_progress(display_progress: bool):
return display_progress and is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage()
def set_epoch(self, epoch):
"""Sets current epoch number.
:param epoch: Epoch number to be set
:type epoch: int
"""
self._cur_epoch = epoch
def _recover_steps(self):
step = self.cur_step * self._engine.schedule.num_steps
self._cur_step = step
def _set_display_progress(self, display_progress: bool):
self._display_progress = display_progress and is_dp_rank_0(
) and is_tp_rank_0() and is_no_pp_or_last_stage()
def _train_epoch(self, epoch: int = None):
def _train_epoch(self,
train_dataloader: DataLoader,
epoch: int = None,
max_steps: int = None,
display_progress: bool = False):
# set sampler epoch
if epoch is not None and \
hasattr(train_dataloader, 'sampler') and \
isinstance(train_dataloader.sampler, DataParallelSampler):
train_dataloader.sampler.set_epoch(epoch)
# set training state
self._engine.train()
data_iter = iter(train_dataloader)
progress = range(self._engine.schedule.num_steps)
if self._display_progress:
progress = range(len(train_dataloader))
if display_progress:
if epoch is None:
progress = tqdm(progress, desc='[Train]')
else:
progress = tqdm(progress, desc=f'[Epoch {epoch} train]')
# train 1 epoch
self.call_hooks('before_train_epoch')
self._timer.start('train-epoch')
for _ in progress:
self.call_hooks('before_train_iter')
self._timer.start('train-step')
logits, label, loss = self._engine.step()
self._timer.stop('train-step', keep_in_history=True)
self.call_hooks('after_train_iter', output=(logits, label, loss))
self._call_hooks('before_train_epoch')
self._call_timer(action='start', item='train-epoch')
for i in progress:
self._call_hooks('before_train_iter')
self._call_timer(action='start', item='train-step')
# run 1 training step
is_last_iteration = (i == len(train_dataloader) - 1)
logits, label, loss = self._engine.step(data_iter, self._model, self._criterion, self._optimizer, is_last_iteration)
self._call_timer(action='stop', item='train-step', keep_in_history=True)
self._call_hooks('after_train_iter', output=(logits, label, loss))
self._cur_step += 1
if self.exceed_max_step():
# stop when max iter is reached
if self._exceed_max_step():
break
self._engine.complete()
self._timer.stop('train-epoch', keep_in_history=True)
self.call_hooks('after_train_epoch')
self._timer.reset('train-step')
self._call_timer(action='stop', item='train-epoch', keep_in_history=True)
self._call_hooks('after_train_epoch')
self._call_timer(action='reset', item='train-step')
def _eval(self,
test_dataloader: DataLoader,
display_progress: bool = False,
epoch: int = None,
return_loss: bool = True):
# switch engine status
self._engine.eval()
self.call_hooks('before_test')
data_iter = iter(test_dataloader)
self._call_hooks('before_test')
with torch.no_grad():
# prepare progress bar
progress = range(self._engine.schedule.num_steps)
if self._display_progress:
progress = range(len(test_dataloader))
if display_progress:
desc = 'Evaluation'
if epoch is not None:
desc = '[Epoch %d val]' % epoch
progress = tqdm(progress, desc=desc)
self.call_hooks('before_test_epoch')
self._timer.start('test-epoch')
self._call_hooks('before_test_epoch')
self._call_timer(action='start', item='test-epoch')
for _ in progress:
self.call_hooks('before_test_iter')
self._timer.start('test-step')
logits, label, loss = self._engine.step(
return_loss=return_loss)
self._timer.stop('test-step', keep_in_history=True)
self.call_hooks('after_test_iter',
self._call_hooks('before_test_iter')
self._call_timer(action='start', item='test-step')
logits, label, loss = self._engine.step(data_iter, self._model, self._criterion, return_loss=return_loss)
self._call_timer(action='stop', item='test-step', keep_in_history=True)
self._call_hooks('after_test_iter',
output=(logits, label, loss))
self._timer.stop('test-epoch', keep_in_history=True)
self.call_hooks('after_test_epoch')
self.call_hooks('after_test')
self._timer.reset('test-step')
self._timer.reset('test-epoch')
self._call_timer(action='stop', item='test-epoch', keep_in_history=True)
self._call_hooks('after_test_epoch')
self._call_hooks('after_test')
self._call_timer(action='reset', item='test-step')
self._call_timer(action='reset', item='test-epoch')
def _exceed_max_step(self):
return self._max_steps is not None and self._cur_step > self._max_steps
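The rewritten check is None-safe, which matters now that max_steps defaults to None in fit. A comment-only comparison of the two versions (my paraphrase of the diff):

# removed: self._cur_step >= self._max_steps   -- assumes _max_steps is numeric
# added:   self._max_steps is not None and self._cur_step > self._max_steps
# i.e. with max_steps left unset, training runs until `epochs` is exhausted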
def fit(self,
train_dataloader: DataLoader,
test_dataloader: DataLoader = None,
max_epochs: int = None,
epochs: int,
max_steps: int = None,
test_dataloader: DataLoader = None,
test_interval: int = 1,
display_progress: bool = False):
hooks_cfg: dict = None,
display_progress: bool = False,
gradient_accumulation: int = 1):
"""Trains the model to fit training data.
:param train_dataloader: DataLoader in training
:param epochs: Number of epochs to run
:param test_dataloader: DataLoader in testing
:param max_epochs: Maximum number of epochs
:param max_steps: Maximum number of running iterations
:param test_dataloader: DataLoader in testing
:param test_interval: Interval of testing
:param hooks_cfg: A list of hook configurations
:param display_progress: If True, the training progress will be printed
:param gradient_accumulation: Number of micro-batch iterations folded into each training step
:type train_dataloader: DataLoader
:type epochs: int
:type test_dataloader: DataLoader
:type max_epochs: int
:type max_steps: int
:type test_dataloader: DataLoader
:type test_interval: int
:type hooks_cfg: dict
:type display_progress: bool
:type gradient_accumulation: int
"""
# prepare dataloaders
self._train_dataloader = train_dataloader
self._engine.set_dataloader(self._train_dataloader, train=True)
self._engine.train()
should_test = False
if test_dataloader is not None:
self._test_dataloader = test_dataloader
self._engine.set_dataloader(self._test_dataloader, train=False)
should_test = True
# decide the maximum number of epochs and steps
if max_epochs is not None:
self._max_epochs = max_epochs
if max_steps is not None:
# set epochs and steps
self._steps_per_epoch = len(train_dataloader) // gradient_accumulation
self._max_steps = max_steps
self._set_display_progress(display_progress)
# start train
self.call_hooks('before_train')
self._max_epochs = epochs
# recover step value if resuming training
if self.cur_epoch != 0:
self._recover_steps()
last_epoch = self._cur_epoch
if self.cur_epoch != 0:
self._set_current_step(last_epoch)
for epoch in range(last_epoch, self._max_epochs):
# check if testing is required
should_test = False
if test_dataloader is not None:
should_test = True
display_progress = self._should_display_progress(display_progress)
# reset hooks
self._reset_states()
self.hooks = list()
# build hooks
if hooks_cfg is not None:
for cfg in hooks_cfg:
hook = build_hooks(cfg, self)
self.hooks.append(hook)
self.hooks.sort(key=lambda hook: hook.priority)
if self._verbose:
for hook in self.hooks:
self._logger.info(
f'build {hook.__class__.__name__} for train, priority = {hook.priority}', ranks=[0])
# start train
self._engine.train()
self._call_hooks('before_train')
for epoch in range(last_epoch, epochs):
# train for one epoch
self._train_epoch(epoch)
self._train_epoch(
train_dataloader=train_dataloader,
epoch=epoch,
max_steps=max_steps,
display_progress=display_progress
)
# start eval
if should_test and epoch % test_interval == 0:
self._eval(epoch, return_loss=True)
self._eval(test_dataloader=test_dataloader,
display_progress=display_progress,
epoch=epoch,
return_loss=True
)
self._cur_epoch += 1
# check for termination
if self.exceed_max_step():
if self._exceed_max_step():
self._logger.info(
f"Max number of steps {self._max_steps} has been reached, training is stopped automatically")
f"Max number of steps {max_steps} has been reached, training is stopped automatically")
break
self.call_hooks('after_train')
self._call_hooks('after_train')
self._call_timer(action='reset', item='train-epoch')
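Hook configuration moves from the constructor into fit. A hedged usage sketch of the new entry point; only the keyword names come from the diff, and the hook type in hooks_cfg is an illustrative placeholder:

# hypothetical call; each hooks_cfg entry is consumed by build_hooks(cfg, trainer)
trainer.fit(train_dataloader=train_loader,
            epochs=10,
            max_steps=None,
            test_dataloader=test_loader,
            test_interval=1,
            hooks_cfg=[dict(type='LogMetricByEpochHook')],  # illustrative hook config
            display_progress=True,
            gradient_accumulation=1)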
def evaluate(self,
@@ -262,15 +330,14 @@ class Trainer:
:type test_dataloader: DataLoader
:type display_progress: bool, optional
"""
# set dataloader
self._test_dataloader = test_dataloader
self._engine.set_dataloader(self._test_dataloader, train=True)
# set
self._set_display_progress(display_progress)
# set display
display_progress = self._should_display_progress(display_progress)
# eval
self._eval(return_loss=True)
self._eval(test_dataloader=test_dataloader,
display_progress=display_progress,
return_loss=True
)
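A short usage sketch for standalone evaluation; the loader name is illustrative:

# hypothetical post-training evaluation
trainer.evaluate(test_dataloader=test_loader, display_progress=True)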
def predict(self, data: Union[Tensor, List[Tensor]]):
"""Uses trained model to make a prediction for a tensor or a tensor list.
@@ -290,8 +357,8 @@ class Trainer:
# prepare a list of (data, label) to make it iterable
# for compatibility with schedule
simple_dataloader = [(data, None)]
self._engine.set_dataloader(simple_dataloader)
output, _, _ = self._engine.step(return_loss=False)
data_iter = iter(simple_dataloader)
output, _, _ = self._engine.step(data_iter, self._model, self._criterion, return_loss=False)
return output
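And a sketch of single-batch inference through predict; the input shape is illustrative:

# hypothetical prediction on one batch; predict wraps it as [(data, None)] internally
sample = torch.randn(4, 3, 32, 32)  # illustrative input shape
prediction = trainer.predict(sample)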
def save(self, path: str, suffix: str = ''):