Browse Source

[doc] improved assertion messages in trainer (#873)

pull/874/head^2
Frank Lee 3 years ago committed by GitHub
parent
commit
1c34382678
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 115
      colossalai/trainer/_trainer.py

115
colossalai/trainer/_trainer.py

@ -1,13 +1,9 @@
from typing import Union, List
from colossalai.context.parallel_mode import ParallelMode
from typing import Union, List, Any
import torch
from torch import Tensor
from torch.utils.data import DataLoader
from tqdm import tqdm
from colossalai.core import global_context as gpc
from colossalai.engine import Engine
from colossalai.logging import DistributedLogger
from colossalai.utils import MultiTimer
@ -53,11 +49,12 @@ class Trainer:
`Training with engine and trainer <https://www.colossalai.org/docs/basics/engine_trainer>`_
and `ColossalAI-Examples <https://github.com/hpcaitech/ColossalAI-Examples/tree/main>`_.
"""
def __init__(
self,
engine: Engine,
timer: MultiTimer = None,
logger: DistributedLogger = None,
self,
engine: Engine,
timer: MultiTimer = None,
logger: DistributedLogger = None,
):
# training-ralated params
self._engine = engine
@ -154,15 +151,14 @@ class Trainer:
@staticmethod
def _should_display_progress(display_progress: bool):
"""Only display progress on DP rank 0, TP rank 0 and PP last rank"""
return (display_progress and is_dp_rank_0() and is_tp_rank_0()
and is_no_pp_or_last_stage())
return (display_progress and is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage())
def _train_epoch(
self,
train_dataloader: DataLoader,
epoch: int = None,
display_progress: bool = False,
return_output_label: bool = True,
self,
train_dataloader: DataLoader,
epoch: int = None,
display_progress: bool = False,
return_output_label: bool = True,
):
# set training state
self._engine.train()
@ -189,9 +185,7 @@ class Trainer:
return_output_label=return_output_label,
)
self.engine.step()
self._call_timer(action="stop",
item="Train-step",
keep_in_history=True)
self._call_timer(action="stop", item="Train-step", keep_in_history=True)
self._call_hooks("after_train_iter", output=(logits, label, loss))
self._cur_step += 1
@ -204,18 +198,16 @@ class Trainer:
if self._exceed_max_step():
break
self._call_timer(action="stop",
item="Train-epoch",
keep_in_history=True)
self._call_timer(action="stop", item="Train-epoch", keep_in_history=True)
self._call_hooks("after_train_epoch")
self._call_timer(action="reset", item="Train-epoch")
def _eval(
self,
test_dataloader: DataLoader,
epoch: int = None,
display_progress: bool = False,
return_output_label: bool = True,
self,
test_dataloader: DataLoader,
epoch: int = None,
display_progress: bool = False,
return_output_label: bool = True,
):
# switch engine status
self._engine.eval()
@ -244,19 +236,14 @@ class Trainer:
return_loss=True,
return_output_label=return_output_label,
)
self._call_timer(action="stop",
item="Test-step",
keep_in_history=True)
self._call_hooks("after_test_iter",
output=(logits, label, loss))
self._call_timer(action="stop", item="Test-step", keep_in_history=True)
self._call_hooks("after_test_iter", output=(logits, label, loss))
if display_progress:
if "step_metrics" in self.states:
progress.set_postfix(**self.states["step_metrics"])
self._call_timer(action="stop",
item="Test-epoch",
keep_in_history=True)
self._call_timer(action="stop", item="Test-epoch", keep_in_history=True)
self._call_hooks("after_test_epoch")
self._call_hooks("after_test")
self._call_timer(action="reset", item="Test-step")
@ -266,15 +253,15 @@ class Trainer:
return self._max_steps is not None and self._cur_step >= self._max_steps
def fit(
self,
train_dataloader: DataLoader,
epochs: int,
max_steps: int = None,
test_dataloader: DataLoader = None,
test_interval: int = 1,
hooks: List[BaseHook] = None,
display_progress: bool = False,
return_output_label: bool = True,
self,
train_dataloader: DataLoader,
epochs: int,
max_steps: int = None,
test_dataloader: DataLoader = None,
test_interval: int = 1,
hooks: List[BaseHook] = None,
display_progress: bool = False,
return_output_label: bool = True,
):
r"""Trains the model to fit training data.
@ -303,9 +290,11 @@ class Trainer:
# reset hooks
self._reset_states()
if hooks is not None:
assert isinstance(
hooks, list
), f"expected argument hooks be to list, but got {type(hooks)}"
assert isinstance(hooks, list), f"expected argument hooks be to list, but got {type(hooks)}"
for hook in hooks:
assert isinstance(hook, BaseHook), \
f'expected the hook to be of type BaseHook, but got {type(hook)}'
else:
hooks = []
self.hooks = hooks
@ -316,9 +305,7 @@ class Trainer:
f"Using {hook.__class__.__name__} for training, priority = {hook.priority}",
ranks=[0],
)
self._logger.info(
"Lower value means higher priority for calling hook function",
ranks=[0])
self._logger.info("Lower value means higher priority for calling hook function", ranks=[0])
self._call_hooks("after_hook_is_attached")
self._engine.train()
@ -360,11 +347,11 @@ class Trainer:
self._call_timer("reset", "Train-epoch")
def evaluate(
self,
test_dataloader: DataLoader,
hooks: List[BaseHook] = None,
display_progress: bool = False,
return_output_label: bool = True,
self,
test_dataloader: DataLoader,
hooks: List[BaseHook] = None,
display_progress: bool = False,
return_output_label: bool = True,
):
"""Evaluates the model with testing data.
@ -381,9 +368,7 @@ class Trainer:
# reset hooks
self._reset_states()
if hooks is not None:
assert isinstance(
hooks, list
), f"expected argument hooks be to list, but got {type(hooks)}"
assert isinstance(hooks, list), f"expected argument hooks be to list, but got {type(hooks)}"
else:
hooks = []
self.hooks = hooks
@ -394,9 +379,7 @@ class Trainer:
f"Using {hook.__class__.__name__} for training, priority = {hook.priority}",
ranks=[0],
)
self._logger.info(
"Lower value means higher priority for calling hook function",
ranks=[0])
self._logger.info("Lower value means higher priority for calling hook function", ranks=[0])
self._call_hooks("after_hook_is_attached")
# eval
@ -406,7 +389,7 @@ class Trainer:
return_output_label=return_output_label,
)
def predict(self, data: Union[Tensor, List[Tensor]]):
def predict(self, data: Union[Any, List[Any]]):
"""Uses trained model to make a prediction for a tensor or a tensor list.
Args:
@ -416,17 +399,11 @@ class Trainer:
:class:`torch.tensor`: The output of model as the prediction
"""
# predict without labels
if isinstance(data, (list, tuple)):
assert isinstance(data[0], Tensor)
else:
assert isinstance(data, Tensor)
self._engine.eval()
# prepare a list of (data, label) to make it iterable
# for compatibility with schedule
simple_dataloader = [(data, None)]
data_iter = iter(simple_dataloader)
output, _, _ = self.engine.execute_schedule(data_iter,
forward_only=True,
return_loss=False)
output, _, _ = self.engine.execute_schedule(data_iter, forward_only=True, return_loss=False)
return output

Loading…
Cancel
Save