mirror of https://github.com/hpcaitech/ColossalAI
[doc] improved assertion messages in trainer (#873)
parent
7a64fae33a
commit
1c34382678
|
@ -1,13 +1,9 @@
|
|||
from typing import Union, List
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from typing import Union, List, Any
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
from colossalai.core import global_context as gpc
|
||||
|
||||
from colossalai.engine import Engine
|
||||
from colossalai.logging import DistributedLogger
|
||||
from colossalai.utils import MultiTimer
|
||||
|
@ -53,6 +49,7 @@ class Trainer:
|
|||
`Training with engine and trainer <https://www.colossalai.org/docs/basics/engine_trainer>`_
|
||||
and `ColossalAI-Examples <https://github.com/hpcaitech/ColossalAI-Examples/tree/main>`_.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine: Engine,
|
||||
|
@ -154,8 +151,7 @@ class Trainer:
|
|||
@staticmethod
def _should_display_progress(display_progress: bool):
    """Only display progress on DP rank 0, TP rank 0 and PP last rank.

    Args:
        display_progress (bool): Whether the caller requested progress display.

    Returns:
        A truthy value only when ``display_progress`` is truthy and
        ``is_dp_rank_0()``, ``is_tp_rank_0()`` and ``is_no_pp_or_last_stage()``
        all return truthy values.
    """
    # `and` short-circuits, so the rank helpers are not called at all when
    # the caller did not ask for progress display.
    return (display_progress and is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage())
|
||||
|
||||
def _train_epoch(
|
||||
self,
|
||||
|
@ -189,9 +185,7 @@ class Trainer:
|
|||
return_output_label=return_output_label,
|
||||
)
|
||||
self.engine.step()
|
||||
self._call_timer(action="stop",
|
||||
item="Train-step",
|
||||
keep_in_history=True)
|
||||
self._call_timer(action="stop", item="Train-step", keep_in_history=True)
|
||||
self._call_hooks("after_train_iter", output=(logits, label, loss))
|
||||
|
||||
self._cur_step += 1
|
||||
|
@ -204,9 +198,7 @@ class Trainer:
|
|||
if self._exceed_max_step():
|
||||
break
|
||||
|
||||
self._call_timer(action="stop",
|
||||
item="Train-epoch",
|
||||
keep_in_history=True)
|
||||
self._call_timer(action="stop", item="Train-epoch", keep_in_history=True)
|
||||
self._call_hooks("after_train_epoch")
|
||||
self._call_timer(action="reset", item="Train-epoch")
|
||||
|
||||
|
@ -244,19 +236,14 @@ class Trainer:
|
|||
return_loss=True,
|
||||
return_output_label=return_output_label,
|
||||
)
|
||||
self._call_timer(action="stop",
|
||||
item="Test-step",
|
||||
keep_in_history=True)
|
||||
self._call_hooks("after_test_iter",
|
||||
output=(logits, label, loss))
|
||||
self._call_timer(action="stop", item="Test-step", keep_in_history=True)
|
||||
self._call_hooks("after_test_iter", output=(logits, label, loss))
|
||||
|
||||
if display_progress:
|
||||
if "step_metrics" in self.states:
|
||||
progress.set_postfix(**self.states["step_metrics"])
|
||||
|
||||
self._call_timer(action="stop",
|
||||
item="Test-epoch",
|
||||
keep_in_history=True)
|
||||
self._call_timer(action="stop", item="Test-epoch", keep_in_history=True)
|
||||
self._call_hooks("after_test_epoch")
|
||||
self._call_hooks("after_test")
|
||||
self._call_timer(action="reset", item="Test-step")
|
||||
|
@ -303,9 +290,11 @@ class Trainer:
|
|||
# reset hooks
|
||||
self._reset_states()
|
||||
if hooks is not None:
|
||||
assert isinstance(
|
||||
hooks, list
|
||||
), f"expected argument hooks be to list, but got {type(hooks)}"
|
||||
assert isinstance(hooks, list), f"expected argument hooks be to list, but got {type(hooks)}"
|
||||
|
||||
for hook in hooks:
|
||||
assert isinstance(hook, BaseHook), \
|
||||
f'expected the hook to be of type BaseHook, but got {type(hook)}'
|
||||
else:
|
||||
hooks = []
|
||||
self.hooks = hooks
|
||||
|
@ -316,9 +305,7 @@ class Trainer:
|
|||
f"Using {hook.__class__.__name__} for training, priority = {hook.priority}",
|
||||
ranks=[0],
|
||||
)
|
||||
self._logger.info(
|
||||
"Lower value means higher priority for calling hook function",
|
||||
ranks=[0])
|
||||
self._logger.info("Lower value means higher priority for calling hook function", ranks=[0])
|
||||
self._call_hooks("after_hook_is_attached")
|
||||
|
||||
self._engine.train()
|
||||
|
@ -381,9 +368,7 @@ class Trainer:
|
|||
# reset hooks
|
||||
self._reset_states()
|
||||
if hooks is not None:
|
||||
assert isinstance(
|
||||
hooks, list
|
||||
), f"expected argument hooks be to list, but got {type(hooks)}"
|
||||
assert isinstance(hooks, list), f"expected argument hooks be to list, but got {type(hooks)}"
|
||||
else:
|
||||
hooks = []
|
||||
self.hooks = hooks
|
||||
|
@ -394,9 +379,7 @@ class Trainer:
|
|||
f"Using {hook.__class__.__name__} for training, priority = {hook.priority}",
|
||||
ranks=[0],
|
||||
)
|
||||
self._logger.info(
|
||||
"Lower value means higher priority for calling hook function",
|
||||
ranks=[0])
|
||||
self._logger.info("Lower value means higher priority for calling hook function", ranks=[0])
|
||||
self._call_hooks("after_hook_is_attached")
|
||||
|
||||
# eval
|
||||
|
@ -406,7 +389,7 @@ class Trainer:
|
|||
return_output_label=return_output_label,
|
||||
)
|
||||
|
||||
def predict(self, data: Union[Tensor, List[Tensor]]):
|
||||
def predict(self, data: Union[Any, List[Any]]):
|
||||
"""Uses trained model to make a prediction for a tensor or a tensor list.
|
||||
|
||||
Args:
|
||||
|
@ -416,17 +399,11 @@ class Trainer:
|
|||
:class:`torch.tensor`: The output of model as the prediction
|
||||
"""
|
||||
# predict without labels
|
||||
if isinstance(data, (list, tuple)):
|
||||
assert isinstance(data[0], Tensor)
|
||||
else:
|
||||
assert isinstance(data, Tensor)
|
||||
self._engine.eval()
|
||||
|
||||
# prepare a list of (data, label) to make it iterable
|
||||
# for compatibility with schedule
|
||||
simple_dataloader = [(data, None)]
|
||||
data_iter = iter(simple_dataloader)
|
||||
output, _, _ = self.engine.execute_schedule(data_iter,
|
||||
forward_only=True,
|
||||
return_loss=False)
|
||||
output, _, _ = self.engine.execute_schedule(data_iter, forward_only=True, return_loss=False)
|
||||
return output
|
||||
|
|
Loading…
Reference in New Issue