diff --git a/colossalai/trainer/_trainer.py b/colossalai/trainer/_trainer.py index a89f2f3aa..60bbc4eee 100644 --- a/colossalai/trainer/_trainer.py +++ b/colossalai/trainer/_trainer.py @@ -1,13 +1,9 @@ -from typing import Union, List -from colossalai.context.parallel_mode import ParallelMode +from typing import Union, List, Any import torch -from torch import Tensor from torch.utils.data import DataLoader from tqdm import tqdm -from colossalai.core import global_context as gpc - from colossalai.engine import Engine from colossalai.logging import DistributedLogger from colossalai.utils import MultiTimer @@ -53,11 +49,12 @@ class Trainer: `Training with engine and trainer `_ and `ColossalAI-Examples `_. """ + def __init__( - self, - engine: Engine, - timer: MultiTimer = None, - logger: DistributedLogger = None, + self, + engine: Engine, + timer: MultiTimer = None, + logger: DistributedLogger = None, ): # training-ralated params self._engine = engine @@ -154,15 +151,14 @@ class Trainer: @staticmethod def _should_display_progress(display_progress: bool): """Only display progress on DP rank 0, TP rank 0 and PP last rank""" - return (display_progress and is_dp_rank_0() and is_tp_rank_0() - and is_no_pp_or_last_stage()) + return (display_progress and is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage()) def _train_epoch( - self, - train_dataloader: DataLoader, - epoch: int = None, - display_progress: bool = False, - return_output_label: bool = True, + self, + train_dataloader: DataLoader, + epoch: int = None, + display_progress: bool = False, + return_output_label: bool = True, ): # set training state self._engine.train() @@ -189,9 +185,7 @@ class Trainer: return_output_label=return_output_label, ) self.engine.step() - self._call_timer(action="stop", - item="Train-step", - keep_in_history=True) + self._call_timer(action="stop", item="Train-step", keep_in_history=True) self._call_hooks("after_train_iter", output=(logits, label, loss)) self._cur_step += 1 @@ -204,18 +198,16 @@ class Trainer: if self._exceed_max_step(): break - self._call_timer(action="stop", - item="Train-epoch", - keep_in_history=True) + self._call_timer(action="stop", item="Train-epoch", keep_in_history=True) self._call_hooks("after_train_epoch") self._call_timer(action="reset", item="Train-epoch") def _eval( - self, - test_dataloader: DataLoader, - epoch: int = None, - display_progress: bool = False, - return_output_label: bool = True, + self, + test_dataloader: DataLoader, + epoch: int = None, + display_progress: bool = False, + return_output_label: bool = True, ): # switch engine status self._engine.eval() @@ -244,19 +236,14 @@ class Trainer: return_loss=True, return_output_label=return_output_label, ) - self._call_timer(action="stop", - item="Test-step", - keep_in_history=True) - self._call_hooks("after_test_iter", - output=(logits, label, loss)) + self._call_timer(action="stop", item="Test-step", keep_in_history=True) + self._call_hooks("after_test_iter", output=(logits, label, loss)) if display_progress: if "step_metrics" in self.states: progress.set_postfix(**self.states["step_metrics"]) - self._call_timer(action="stop", - item="Test-epoch", - keep_in_history=True) + self._call_timer(action="stop", item="Test-epoch", keep_in_history=True) self._call_hooks("after_test_epoch") self._call_hooks("after_test") self._call_timer(action="reset", item="Test-step") @@ -266,15 +253,15 @@ class Trainer: return self._max_steps is not None and self._cur_step >= self._max_steps def fit( - self, - train_dataloader: DataLoader, - epochs: int, - max_steps: int = None, - test_dataloader: DataLoader = None, - test_interval: int = 1, - hooks: List[BaseHook] = None, - display_progress: bool = False, - return_output_label: bool = True, + self, + train_dataloader: DataLoader, + epochs: int, + max_steps: int = None, + test_dataloader: DataLoader = None, + test_interval: int = 1, + hooks: List[BaseHook] = None, + display_progress: bool = False, + return_output_label: bool = True, ): r"""Trains the model to fit training data. @@ -303,9 +290,11 @@ class Trainer: # reset hooks self._reset_states() if hooks is not None: - assert isinstance( - hooks, list - ), f"expected argument hooks be to list, but got {type(hooks)}" + assert isinstance(hooks, list), f"expected argument hooks be to list, but got {type(hooks)}" + + for hook in hooks: + assert isinstance(hook, BaseHook), \ + f'expected the hook to be of type BaseHook, but got {type(hook)}' else: hooks = [] self.hooks = hooks @@ -316,9 +305,7 @@ class Trainer: f"Using {hook.__class__.__name__} for training, priority = {hook.priority}", ranks=[0], ) - self._logger.info( - "Lower value means higher priority for calling hook function", - ranks=[0]) + self._logger.info("Lower value means higher priority for calling hook function", ranks=[0]) self._call_hooks("after_hook_is_attached") self._engine.train() @@ -360,11 +347,11 @@ class Trainer: self._call_timer("reset", "Train-epoch") def evaluate( - self, - test_dataloader: DataLoader, - hooks: List[BaseHook] = None, - display_progress: bool = False, - return_output_label: bool = True, + self, + test_dataloader: DataLoader, + hooks: List[BaseHook] = None, + display_progress: bool = False, + return_output_label: bool = True, ): """Evaluates the model with testing data. @@ -381,9 +368,7 @@ class Trainer: # reset hooks self._reset_states() if hooks is not None: - assert isinstance( - hooks, list - ), f"expected argument hooks be to list, but got {type(hooks)}" + assert isinstance(hooks, list), f"expected argument hooks be to list, but got {type(hooks)}" else: hooks = [] self.hooks = hooks @@ -394,9 +379,7 @@ class Trainer: f"Using {hook.__class__.__name__} for training, priority = {hook.priority}", ranks=[0], ) - self._logger.info( - "Lower value means higher priority for calling hook function", - ranks=[0]) + self._logger.info("Lower value means higher priority for calling hook function", ranks=[0]) self._call_hooks("after_hook_is_attached") # eval @@ -406,7 +389,7 @@ class Trainer: return_output_label=return_output_label, ) - def predict(self, data: Union[Tensor, List[Tensor]]): + def predict(self, data: Union[Any, List[Any]]): """Uses trained model to make a prediction for a tensor or a tensor list. Args: @@ -416,17 +399,11 @@ class Trainer: :class:`torch.tensor`: The output of model as the prediction """ # predict without labels - if isinstance(data, (list, tuple)): - assert isinstance(data[0], Tensor) - else: - assert isinstance(data, Tensor) self._engine.eval() # prepare a list of (data, label) to make it iterable # for compatibility with schedule simple_dataloader = [(data, None)] data_iter = iter(simple_dataloader) - output, _, _ = self.engine.execute_schedule(data_iter, - forward_only=True, - return_loss=False) + output, _, _ = self.engine.execute_schedule(data_iter, forward_only=True, return_loss=False) return output