[doc] improved assertion messages in trainer (#873)

pull/874/head^2
Frank Lee 3 years ago committed by GitHub
parent 7a64fae33a
commit 1c34382678
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,13 +1,9 @@
from typing import Union, List from typing import Union, List, Any
from colossalai.context.parallel_mode import ParallelMode
import torch import torch
from torch import Tensor
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from tqdm import tqdm from tqdm import tqdm
from colossalai.core import global_context as gpc
from colossalai.engine import Engine from colossalai.engine import Engine
from colossalai.logging import DistributedLogger from colossalai.logging import DistributedLogger
from colossalai.utils import MultiTimer from colossalai.utils import MultiTimer
@ -53,6 +49,7 @@ class Trainer:
`Training with engine and trainer <https://www.colossalai.org/docs/basics/engine_trainer>`_ `Training with engine and trainer <https://www.colossalai.org/docs/basics/engine_trainer>`_
and `ColossalAI-Examples <https://github.com/hpcaitech/ColossalAI-Examples/tree/main>`_. and `ColossalAI-Examples <https://github.com/hpcaitech/ColossalAI-Examples/tree/main>`_.
""" """
def __init__( def __init__(
self, self,
engine: Engine, engine: Engine,
@ -154,8 +151,7 @@ class Trainer:
@staticmethod @staticmethod
def _should_display_progress(display_progress: bool): def _should_display_progress(display_progress: bool):
"""Only display progress on DP rank 0, TP rank 0 and PP last rank""" """Only display progress on DP rank 0, TP rank 0 and PP last rank"""
return (display_progress and is_dp_rank_0() and is_tp_rank_0() return (display_progress and is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage())
and is_no_pp_or_last_stage())
def _train_epoch( def _train_epoch(
self, self,
@ -189,9 +185,7 @@ class Trainer:
return_output_label=return_output_label, return_output_label=return_output_label,
) )
self.engine.step() self.engine.step()
self._call_timer(action="stop", self._call_timer(action="stop", item="Train-step", keep_in_history=True)
item="Train-step",
keep_in_history=True)
self._call_hooks("after_train_iter", output=(logits, label, loss)) self._call_hooks("after_train_iter", output=(logits, label, loss))
self._cur_step += 1 self._cur_step += 1
@ -204,9 +198,7 @@ class Trainer:
if self._exceed_max_step(): if self._exceed_max_step():
break break
self._call_timer(action="stop", self._call_timer(action="stop", item="Train-epoch", keep_in_history=True)
item="Train-epoch",
keep_in_history=True)
self._call_hooks("after_train_epoch") self._call_hooks("after_train_epoch")
self._call_timer(action="reset", item="Train-epoch") self._call_timer(action="reset", item="Train-epoch")
@ -244,19 +236,14 @@ class Trainer:
return_loss=True, return_loss=True,
return_output_label=return_output_label, return_output_label=return_output_label,
) )
self._call_timer(action="stop", self._call_timer(action="stop", item="Test-step", keep_in_history=True)
item="Test-step", self._call_hooks("after_test_iter", output=(logits, label, loss))
keep_in_history=True)
self._call_hooks("after_test_iter",
output=(logits, label, loss))
if display_progress: if display_progress:
if "step_metrics" in self.states: if "step_metrics" in self.states:
progress.set_postfix(**self.states["step_metrics"]) progress.set_postfix(**self.states["step_metrics"])
self._call_timer(action="stop", self._call_timer(action="stop", item="Test-epoch", keep_in_history=True)
item="Test-epoch",
keep_in_history=True)
self._call_hooks("after_test_epoch") self._call_hooks("after_test_epoch")
self._call_hooks("after_test") self._call_hooks("after_test")
self._call_timer(action="reset", item="Test-step") self._call_timer(action="reset", item="Test-step")
@ -303,9 +290,11 @@ class Trainer:
# reset hooks # reset hooks
self._reset_states() self._reset_states()
if hooks is not None: if hooks is not None:
assert isinstance( assert isinstance(hooks, list), f"expected argument hooks be to list, but got {type(hooks)}"
hooks, list
), f"expected argument hooks be to list, but got {type(hooks)}" for hook in hooks:
assert isinstance(hook, BaseHook), \
f'expected the hook to be of type BaseHook, but got {type(hook)}'
else: else:
hooks = [] hooks = []
self.hooks = hooks self.hooks = hooks
@ -316,9 +305,7 @@ class Trainer:
f"Using {hook.__class__.__name__} for training, priority = {hook.priority}", f"Using {hook.__class__.__name__} for training, priority = {hook.priority}",
ranks=[0], ranks=[0],
) )
self._logger.info( self._logger.info("Lower value means higher priority for calling hook function", ranks=[0])
"Lower value means higher priority for calling hook function",
ranks=[0])
self._call_hooks("after_hook_is_attached") self._call_hooks("after_hook_is_attached")
self._engine.train() self._engine.train()
@ -381,9 +368,7 @@ class Trainer:
# reset hooks # reset hooks
self._reset_states() self._reset_states()
if hooks is not None: if hooks is not None:
assert isinstance( assert isinstance(hooks, list), f"expected argument hooks be to list, but got {type(hooks)}"
hooks, list
), f"expected argument hooks be to list, but got {type(hooks)}"
else: else:
hooks = [] hooks = []
self.hooks = hooks self.hooks = hooks
@ -394,9 +379,7 @@ class Trainer:
f"Using {hook.__class__.__name__} for training, priority = {hook.priority}", f"Using {hook.__class__.__name__} for training, priority = {hook.priority}",
ranks=[0], ranks=[0],
) )
self._logger.info( self._logger.info("Lower value means higher priority for calling hook function", ranks=[0])
"Lower value means higher priority for calling hook function",
ranks=[0])
self._call_hooks("after_hook_is_attached") self._call_hooks("after_hook_is_attached")
# eval # eval
@ -406,7 +389,7 @@ class Trainer:
return_output_label=return_output_label, return_output_label=return_output_label,
) )
def predict(self, data: Union[Tensor, List[Tensor]]): def predict(self, data: Union[Any, List[Any]]):
"""Uses trained model to make a prediction for a tensor or a tensor list. """Uses trained model to make a prediction for a tensor or a tensor list.
Args: Args:
@ -416,17 +399,11 @@ class Trainer:
:class:`torch.tensor`: The output of model as the prediction :class:`torch.tensor`: The output of model as the prediction
""" """
# predict without labels # predict without labels
if isinstance(data, (list, tuple)):
assert isinstance(data[0], Tensor)
else:
assert isinstance(data, Tensor)
self._engine.eval() self._engine.eval()
# prepare a list of (data, label) to make it iterable # prepare a list of (data, label) to make it iterable
# for compatibility with schedule # for compatibility with schedule
simple_dataloader = [(data, None)] simple_dataloader = [(data, None)]
data_iter = iter(simple_dataloader) data_iter = iter(simple_dataloader)
output, _, _ = self.engine.execute_schedule(data_iter, output, _, _ = self.engine.execute_schedule(data_iter, forward_only=True, return_loss=False)
forward_only=True,
return_loss=False)
return output return output

Loading…
Cancel
Save