mirror of https://github.com/hpcaitech/ColossalAI
[doc] improved assertion messages in trainer (#873)
parent
7a64fae33a
commit
1c34382678
|
@ -1,13 +1,9 @@
|
|||
from typing import Union, List
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from typing import Union, List, Any
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
from colossalai.core import global_context as gpc
|
||||
|
||||
from colossalai.engine import Engine
|
||||
from colossalai.logging import DistributedLogger
|
||||
from colossalai.utils import MultiTimer
|
||||
|
@ -53,11 +49,12 @@ class Trainer:
|
|||
`Training with engine and trainer <https://www.colossalai.org/docs/basics/engine_trainer>`_
|
||||
and `ColossalAI-Examples <https://github.com/hpcaitech/ColossalAI-Examples/tree/main>`_.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine: Engine,
|
||||
timer: MultiTimer = None,
|
||||
logger: DistributedLogger = None,
|
||||
self,
|
||||
engine: Engine,
|
||||
timer: MultiTimer = None,
|
||||
logger: DistributedLogger = None,
|
||||
):
|
||||
# training-related params
|
||||
self._engine = engine
|
||||
|
@ -154,15 +151,14 @@ class Trainer:
|
|||
@staticmethod
|
||||
def _should_display_progress(display_progress: bool):
|
||||
"""Only display progress on DP rank 0, TP rank 0 and PP last rank"""
|
||||
return (display_progress and is_dp_rank_0() and is_tp_rank_0()
|
||||
and is_no_pp_or_last_stage())
|
||||
return (display_progress and is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage())
|
||||
|
||||
def _train_epoch(
|
||||
self,
|
||||
train_dataloader: DataLoader,
|
||||
epoch: int = None,
|
||||
display_progress: bool = False,
|
||||
return_output_label: bool = True,
|
||||
self,
|
||||
train_dataloader: DataLoader,
|
||||
epoch: int = None,
|
||||
display_progress: bool = False,
|
||||
return_output_label: bool = True,
|
||||
):
|
||||
# set training state
|
||||
self._engine.train()
|
||||
|
@ -189,9 +185,7 @@ class Trainer:
|
|||
return_output_label=return_output_label,
|
||||
)
|
||||
self.engine.step()
|
||||
self._call_timer(action="stop",
|
||||
item="Train-step",
|
||||
keep_in_history=True)
|
||||
self._call_timer(action="stop", item="Train-step", keep_in_history=True)
|
||||
self._call_hooks("after_train_iter", output=(logits, label, loss))
|
||||
|
||||
self._cur_step += 1
|
||||
|
@ -204,18 +198,16 @@ class Trainer:
|
|||
if self._exceed_max_step():
|
||||
break
|
||||
|
||||
self._call_timer(action="stop",
|
||||
item="Train-epoch",
|
||||
keep_in_history=True)
|
||||
self._call_timer(action="stop", item="Train-epoch", keep_in_history=True)
|
||||
self._call_hooks("after_train_epoch")
|
||||
self._call_timer(action="reset", item="Train-epoch")
|
||||
|
||||
def _eval(
|
||||
self,
|
||||
test_dataloader: DataLoader,
|
||||
epoch: int = None,
|
||||
display_progress: bool = False,
|
||||
return_output_label: bool = True,
|
||||
self,
|
||||
test_dataloader: DataLoader,
|
||||
epoch: int = None,
|
||||
display_progress: bool = False,
|
||||
return_output_label: bool = True,
|
||||
):
|
||||
# switch engine status
|
||||
self._engine.eval()
|
||||
|
@ -244,19 +236,14 @@ class Trainer:
|
|||
return_loss=True,
|
||||
return_output_label=return_output_label,
|
||||
)
|
||||
self._call_timer(action="stop",
|
||||
item="Test-step",
|
||||
keep_in_history=True)
|
||||
self._call_hooks("after_test_iter",
|
||||
output=(logits, label, loss))
|
||||
self._call_timer(action="stop", item="Test-step", keep_in_history=True)
|
||||
self._call_hooks("after_test_iter", output=(logits, label, loss))
|
||||
|
||||
if display_progress:
|
||||
if "step_metrics" in self.states:
|
||||
progress.set_postfix(**self.states["step_metrics"])
|
||||
|
||||
self._call_timer(action="stop",
|
||||
item="Test-epoch",
|
||||
keep_in_history=True)
|
||||
self._call_timer(action="stop", item="Test-epoch", keep_in_history=True)
|
||||
self._call_hooks("after_test_epoch")
|
||||
self._call_hooks("after_test")
|
||||
self._call_timer(action="reset", item="Test-step")
|
||||
|
@ -266,15 +253,15 @@ class Trainer:
|
|||
return self._max_steps is not None and self._cur_step >= self._max_steps
|
||||
|
||||
def fit(
|
||||
self,
|
||||
train_dataloader: DataLoader,
|
||||
epochs: int,
|
||||
max_steps: int = None,
|
||||
test_dataloader: DataLoader = None,
|
||||
test_interval: int = 1,
|
||||
hooks: List[BaseHook] = None,
|
||||
display_progress: bool = False,
|
||||
return_output_label: bool = True,
|
||||
self,
|
||||
train_dataloader: DataLoader,
|
||||
epochs: int,
|
||||
max_steps: int = None,
|
||||
test_dataloader: DataLoader = None,
|
||||
test_interval: int = 1,
|
||||
hooks: List[BaseHook] = None,
|
||||
display_progress: bool = False,
|
||||
return_output_label: bool = True,
|
||||
):
|
||||
r"""Trains the model to fit training data.
|
||||
|
||||
|
@ -303,9 +290,11 @@ class Trainer:
|
|||
# reset hooks
|
||||
self._reset_states()
|
||||
if hooks is not None:
|
||||
assert isinstance(
|
||||
hooks, list
|
||||
), f"expected argument hooks be to list, but got {type(hooks)}"
|
||||
assert isinstance(hooks, list), f"expected argument hooks be to list, but got {type(hooks)}"
|
||||
|
||||
for hook in hooks:
|
||||
assert isinstance(hook, BaseHook), \
|
||||
f'expected the hook to be of type BaseHook, but got {type(hook)}'
|
||||
else:
|
||||
hooks = []
|
||||
self.hooks = hooks
|
||||
|
@ -316,9 +305,7 @@ class Trainer:
|
|||
f"Using {hook.__class__.__name__} for training, priority = {hook.priority}",
|
||||
ranks=[0],
|
||||
)
|
||||
self._logger.info(
|
||||
"Lower value means higher priority for calling hook function",
|
||||
ranks=[0])
|
||||
self._logger.info("Lower value means higher priority for calling hook function", ranks=[0])
|
||||
self._call_hooks("after_hook_is_attached")
|
||||
|
||||
self._engine.train()
|
||||
|
@ -360,11 +347,11 @@ class Trainer:
|
|||
self._call_timer("reset", "Train-epoch")
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
test_dataloader: DataLoader,
|
||||
hooks: List[BaseHook] = None,
|
||||
display_progress: bool = False,
|
||||
return_output_label: bool = True,
|
||||
self,
|
||||
test_dataloader: DataLoader,
|
||||
hooks: List[BaseHook] = None,
|
||||
display_progress: bool = False,
|
||||
return_output_label: bool = True,
|
||||
):
|
||||
"""Evaluates the model with testing data.
|
||||
|
||||
|
@ -381,9 +368,7 @@ class Trainer:
|
|||
# reset hooks
|
||||
self._reset_states()
|
||||
if hooks is not None:
|
||||
assert isinstance(
|
||||
hooks, list
|
||||
), f"expected argument hooks be to list, but got {type(hooks)}"
|
||||
assert isinstance(hooks, list), f"expected argument hooks be to list, but got {type(hooks)}"
|
||||
else:
|
||||
hooks = []
|
||||
self.hooks = hooks
|
||||
|
@ -394,9 +379,7 @@ class Trainer:
|
|||
f"Using {hook.__class__.__name__} for training, priority = {hook.priority}",
|
||||
ranks=[0],
|
||||
)
|
||||
self._logger.info(
|
||||
"Lower value means higher priority for calling hook function",
|
||||
ranks=[0])
|
||||
self._logger.info("Lower value means higher priority for calling hook function", ranks=[0])
|
||||
self._call_hooks("after_hook_is_attached")
|
||||
|
||||
# eval
|
||||
|
@ -406,7 +389,7 @@ class Trainer:
|
|||
return_output_label=return_output_label,
|
||||
)
|
||||
|
||||
def predict(self, data: Union[Tensor, List[Tensor]]):
|
||||
def predict(self, data: Union[Any, List[Any]]):
|
||||
"""Uses trained model to make a prediction for a tensor or a tensor list.
|
||||
|
||||
Args:
|
||||
|
@ -416,17 +399,11 @@ class Trainer:
|
|||
:class:`torch.tensor`: The output of model as the prediction
|
||||
"""
|
||||
# predict without labels
|
||||
if isinstance(data, (list, tuple)):
|
||||
assert isinstance(data[0], Tensor)
|
||||
else:
|
||||
assert isinstance(data, Tensor)
|
||||
self._engine.eval()
|
||||
|
||||
# prepare a list of (data, label) to make it iterable
|
||||
# for compatibility with schedule
|
||||
simple_dataloader = [(data, None)]
|
||||
data_iter = iter(simple_dataloader)
|
||||
output, _, _ = self.engine.execute_schedule(data_iter,
|
||||
forward_only=True,
|
||||
return_loss=False)
|
||||
output, _, _ = self.engine.execute_schedule(data_iter, forward_only=True, return_loss=False)
|
||||
return output
|
||||
|
|
Loading…
Reference in New Issue