mirror of https://github.com/hpcaitech/ColossalAI
[doc] improved assertion messages in trainer (#873)
parent 7a64fae33a
commit 1c34382678
@@ -1,13 +1,9 @@
-from typing import Union, List
-from colossalai.context.parallel_mode import ParallelMode
+from typing import Union, List, Any

 import torch
-from torch import Tensor
 from torch.utils.data import DataLoader
 from tqdm import tqdm

-from colossalai.core import global_context as gpc
-
 from colossalai.engine import Engine
 from colossalai.logging import DistributedLogger
 from colossalai.utils import MultiTimer
@@ -53,6 +49,7 @@ class Trainer:
     `Training with engine and trainer <https://www.colossalai.org/docs/basics/engine_trainer>`_
     and `ColossalAI-Examples <https://github.com/hpcaitech/ColossalAI-Examples/tree/main>`_.
     """
+
     def __init__(
         self,
         engine: Engine,
@@ -154,8 +151,7 @@
     @staticmethod
     def _should_display_progress(display_progress: bool):
         """Only display progress on DP rank 0, TP rank 0 and PP last rank"""
-        return (display_progress and is_dp_rank_0() and is_tp_rank_0()
-                and is_no_pp_or_last_stage())
+        return (display_progress and is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage())

     def _train_epoch(
         self,
@@ -189,9 +185,7 @@
                 return_output_label=return_output_label,
             )
             self.engine.step()
-            self._call_timer(action="stop",
-                             item="Train-step",
-                             keep_in_history=True)
+            self._call_timer(action="stop", item="Train-step", keep_in_history=True)
             self._call_hooks("after_train_iter", output=(logits, label, loss))

             self._cur_step += 1
@@ -204,9 +198,7 @@
             if self._exceed_max_step():
                 break

-        self._call_timer(action="stop",
-                         item="Train-epoch",
-                         keep_in_history=True)
+        self._call_timer(action="stop", item="Train-epoch", keep_in_history=True)
         self._call_hooks("after_train_epoch")
         self._call_timer(action="reset", item="Train-epoch")

@@ -244,19 +236,14 @@
                     return_loss=True,
                     return_output_label=return_output_label,
                 )
-                self._call_timer(action="stop",
-                                 item="Test-step",
-                                 keep_in_history=True)
-                self._call_hooks("after_test_iter",
-                                 output=(logits, label, loss))
+                self._call_timer(action="stop", item="Test-step", keep_in_history=True)
+                self._call_hooks("after_test_iter", output=(logits, label, loss))

                 if display_progress:
                     if "step_metrics" in self.states:
                         progress.set_postfix(**self.states["step_metrics"])

-        self._call_timer(action="stop",
-                         item="Test-epoch",
-                         keep_in_history=True)
+        self._call_timer(action="stop", item="Test-epoch", keep_in_history=True)
         self._call_hooks("after_test_epoch")
         self._call_hooks("after_test")
         self._call_timer(action="reset", item="Test-step")
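The `_call_timer` calls collapsed onto single lines above all follow the same start/stop/keep_in_history pattern around `MultiTimer`. As a rough illustration only, the following simplified dispatcher mimics that pattern; the class and function here are hypothetical stand-ins, not the ColossalAI implementation.

# Hypothetical sketch of the action/item/keep_in_history pattern shown above;
# not the ColossalAI MultiTimer or Trainer._call_timer implementation.
import time

class SimpleMultiTimer:
    def __init__(self):
        self._starts = {}   # item name -> start timestamp
        self.history = {}   # item name -> list of recorded durations

    def start(self, item):
        self._starts[item] = time.time()

    def stop(self, item, keep_in_history=False):
        elapsed = time.time() - self._starts.pop(item)
        if keep_in_history:
            self.history.setdefault(item, []).append(elapsed)
        return elapsed

    def reset(self, item):
        self.history.pop(item, None)

def call_timer(timer, action, item, **kwargs):
    # Dispatch "start" / "stop" / "reset" to the timer, mirroring
    # self._call_timer(action=..., item=..., keep_in_history=...) above.
    if timer is not None:
        return getattr(timer, action)(item, **kwargs)

With this sketch, call_timer(timer, action="stop", item="Test-step", keep_in_history=True) records the elapsed step time into timer.history["Test-step"].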
@@ -303,9 +290,11 @@
         # reset hooks
         self._reset_states()
         if hooks is not None:
-            assert isinstance(
-                hooks, list
-            ), f"expected argument hooks be to list, but got {type(hooks)}"
+            assert isinstance(hooks, list), f"expected argument hooks be to list, but got {type(hooks)}"
+            for hook in hooks:
+                assert isinstance(hook, BaseHook), \
+                    f'expected the hook to be of type BaseHook, but got {type(hook)}'
         else:
             hooks = []
         self.hooks = hooks
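For reference, the single-line assertion above fails like this when `hooks` is not a list; the snippet is plain Python for illustration and the bad value is made up.

hooks = "not a list"   # deliberately the wrong type
try:
    assert isinstance(hooks, list), f"expected argument hooks be to list, but got {type(hooks)}"
except AssertionError as err:
    print(err)   # prints: expected argument hooks be to list, but got <class 'str'>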
@@ -316,9 +305,7 @@
                 f"Using {hook.__class__.__name__} for training, priority = {hook.priority}",
                 ranks=[0],
             )
-        self._logger.info(
-            "Lower value means higher priority for calling hook function",
-            ranks=[0])
+        self._logger.info("Lower value means higher priority for calling hook function", ranks=[0])
         self._call_hooks("after_hook_is_attached")

         self._engine.train()
@@ -381,9 +368,7 @@
         # reset hooks
         self._reset_states()
         if hooks is not None:
-            assert isinstance(
-                hooks, list
-            ), f"expected argument hooks be to list, but got {type(hooks)}"
+            assert isinstance(hooks, list), f"expected argument hooks be to list, but got {type(hooks)}"
         else:
             hooks = []
         self.hooks = hooks
@@ -394,9 +379,7 @@
                 f"Using {hook.__class__.__name__} for training, priority = {hook.priority}",
                 ranks=[0],
             )
-        self._logger.info(
-            "Lower value means higher priority for calling hook function",
-            ranks=[0])
+        self._logger.info("Lower value means higher priority for calling hook function", ranks=[0])
         self._call_hooks("after_hook_is_attached")

         # eval
@@ -406,7 +389,7 @@
             return_output_label=return_output_label,
         )

-    def predict(self, data: Union[Tensor, List[Tensor]]):
+    def predict(self, data: Union[Any, List[Any]]):
         """Uses trained model to make a prediction for a tensor or a tensor list.

         Args:
@@ -416,17 +399,11 @@
             :class:`torch.tensor`: The output of model as the prediction
         """
         # predict without labels
-        if isinstance(data, (list, tuple)):
-            assert isinstance(data[0], Tensor)
-        else:
-            assert isinstance(data, Tensor)
         self._engine.eval()

         # prepare a list of (data, label) to make it iterable
         # for compatibility with schedule
         simple_dataloader = [(data, None)]
         data_iter = iter(simple_dataloader)
-        output, _, _ = self.engine.execute_schedule(data_iter,
-                                                    forward_only=True,
-                                                    return_loss=False)
+        output, _, _ = self.engine.execute_schedule(data_iter, forward_only=True, return_loss=False)
         return output
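With the annotation widened from `Tensor` to `Any` and the type asserts removed, `predict` accepts whatever batch object the configured schedule can unpack. A hedged usage sketch, where the `trainer` argument is assumed to be a `Trainer` wrapping an initialized `Engine` as in the file above:

import torch

def run_prediction(trainer):
    # `trainer` is assumed to be a colossalai Trainer built around an
    # initialized Engine; the input shapes here are made up for illustration.
    single_input = torch.randn(4, 3, 32, 32)
    out = trainer.predict(single_input)                         # a single tensor still works
    paired_input = [single_input, torch.randint(0, 10, (4,))]   # a list of arbitrary objects
    out = trainer.predict(paired_input)                         # is now accepted as well
    return out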