|
|
|
@ -1,13 +1,9 @@
|
|
|
|
|
from typing import Union, List
|
|
|
|
|
from colossalai.context.parallel_mode import ParallelMode
|
|
|
|
|
from typing import Union, List, Any
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
from torch import Tensor
|
|
|
|
|
from torch.utils.data import DataLoader
|
|
|
|
|
from tqdm import tqdm
|
|
|
|
|
|
|
|
|
|
from colossalai.core import global_context as gpc
|
|
|
|
|
|
|
|
|
|
from colossalai.engine import Engine
|
|
|
|
|
from colossalai.logging import DistributedLogger
|
|
|
|
|
from colossalai.utils import MultiTimer
|
|
|
|
@ -53,11 +49,12 @@ class Trainer:
|
|
|
|
|
`Training with engine and trainer <https://www.colossalai.org/docs/basics/engine_trainer>`_
|
|
|
|
|
and `ColossalAI-Examples <https://github.com/hpcaitech/ColossalAI-Examples/tree/main>`_.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
engine: Engine,
|
|
|
|
|
timer: MultiTimer = None,
|
|
|
|
|
logger: DistributedLogger = None,
|
|
|
|
|
self,
|
|
|
|
|
engine: Engine,
|
|
|
|
|
timer: MultiTimer = None,
|
|
|
|
|
logger: DistributedLogger = None,
|
|
|
|
|
):
|
|
|
|
|
# training-related params
|
|
|
|
|
self._engine = engine
|
|
|
|
@ -154,15 +151,14 @@ class Trainer:
|
|
|
|
|
@staticmethod
def _should_display_progress(display_progress: bool):
    """Decide whether the current process should display progress.

    Only display progress on DP rank 0, TP rank 0 and PP last rank.

    Args:
        display_progress (bool): Whether the caller requested progress display.

    Returns:
        The value of ``display_progress and is_dp_rank_0() and is_tp_rank_0()
        and is_no_pp_or_last_stage()``. Evaluation short-circuits, so the rank
        helpers are only called while every earlier operand is truthy.
    """
    # A single return: the earlier two-line form of this same expression was
    # equivalent, which made the second statement unreachable.
    return (display_progress and is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage())
|
|
|
|
|
|
|
|
|
|
def _train_epoch(
|
|
|
|
|
self,
|
|
|
|
|
train_dataloader: DataLoader,
|
|
|
|
|
epoch: int = None,
|
|
|
|
|
display_progress: bool = False,
|
|
|
|
|
return_output_label: bool = True,
|
|
|
|
|
self,
|
|
|
|
|
train_dataloader: DataLoader,
|
|
|
|
|
epoch: int = None,
|
|
|
|
|
display_progress: bool = False,
|
|
|
|
|
return_output_label: bool = True,
|
|
|
|
|
):
|
|
|
|
|
# set training state
|
|
|
|
|
self._engine.train()
|
|
|
|
@ -189,9 +185,7 @@ class Trainer:
|
|
|
|
|
return_output_label=return_output_label,
|
|
|
|
|
)
|
|
|
|
|
self.engine.step()
|
|
|
|
|
self._call_timer(action="stop",
|
|
|
|
|
item="Train-step",
|
|
|
|
|
keep_in_history=True)
|
|
|
|
|
self._call_timer(action="stop", item="Train-step", keep_in_history=True)
|
|
|
|
|
self._call_hooks("after_train_iter", output=(logits, label, loss))
|
|
|
|
|
|
|
|
|
|
self._cur_step += 1
|
|
|
|
@ -204,18 +198,16 @@ class Trainer:
|
|
|
|
|
if self._exceed_max_step():
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
self._call_timer(action="stop",
|
|
|
|
|
item="Train-epoch",
|
|
|
|
|
keep_in_history=True)
|
|
|
|
|
self._call_timer(action="stop", item="Train-epoch", keep_in_history=True)
|
|
|
|
|
self._call_hooks("after_train_epoch")
|
|
|
|
|
self._call_timer(action="reset", item="Train-epoch")
|
|
|
|
|
|
|
|
|
|
def _eval(
|
|
|
|
|
self,
|
|
|
|
|
test_dataloader: DataLoader,
|
|
|
|
|
epoch: int = None,
|
|
|
|
|
display_progress: bool = False,
|
|
|
|
|
return_output_label: bool = True,
|
|
|
|
|
self,
|
|
|
|
|
test_dataloader: DataLoader,
|
|
|
|
|
epoch: int = None,
|
|
|
|
|
display_progress: bool = False,
|
|
|
|
|
return_output_label: bool = True,
|
|
|
|
|
):
|
|
|
|
|
# switch engine status
|
|
|
|
|
self._engine.eval()
|
|
|
|
@ -244,19 +236,14 @@ class Trainer:
|
|
|
|
|
return_loss=True,
|
|
|
|
|
return_output_label=return_output_label,
|
|
|
|
|
)
|
|
|
|
|
self._call_timer(action="stop",
|
|
|
|
|
item="Test-step",
|
|
|
|
|
keep_in_history=True)
|
|
|
|
|
self._call_hooks("after_test_iter",
|
|
|
|
|
output=(logits, label, loss))
|
|
|
|
|
self._call_timer(action="stop", item="Test-step", keep_in_history=True)
|
|
|
|
|
self._call_hooks("after_test_iter", output=(logits, label, loss))
|
|
|
|
|
|
|
|
|
|
if display_progress:
|
|
|
|
|
if "step_metrics" in self.states:
|
|
|
|
|
progress.set_postfix(**self.states["step_metrics"])
|
|
|
|
|
|
|
|
|
|
self._call_timer(action="stop",
|
|
|
|
|
item="Test-epoch",
|
|
|
|
|
keep_in_history=True)
|
|
|
|
|
self._call_timer(action="stop", item="Test-epoch", keep_in_history=True)
|
|
|
|
|
self._call_hooks("after_test_epoch")
|
|
|
|
|
self._call_hooks("after_test")
|
|
|
|
|
self._call_timer(action="reset", item="Test-step")
|
|
|
|
@ -266,15 +253,15 @@ class Trainer:
|
|
|
|
|
return self._max_steps is not None and self._cur_step >= self._max_steps
|
|
|
|
|
|
|
|
|
|
def fit(
|
|
|
|
|
self,
|
|
|
|
|
train_dataloader: DataLoader,
|
|
|
|
|
epochs: int,
|
|
|
|
|
max_steps: int = None,
|
|
|
|
|
test_dataloader: DataLoader = None,
|
|
|
|
|
test_interval: int = 1,
|
|
|
|
|
hooks: List[BaseHook] = None,
|
|
|
|
|
display_progress: bool = False,
|
|
|
|
|
return_output_label: bool = True,
|
|
|
|
|
self,
|
|
|
|
|
train_dataloader: DataLoader,
|
|
|
|
|
epochs: int,
|
|
|
|
|
max_steps: int = None,
|
|
|
|
|
test_dataloader: DataLoader = None,
|
|
|
|
|
test_interval: int = 1,
|
|
|
|
|
hooks: List[BaseHook] = None,
|
|
|
|
|
display_progress: bool = False,
|
|
|
|
|
return_output_label: bool = True,
|
|
|
|
|
):
|
|
|
|
|
r"""Trains the model to fit training data.
|
|
|
|
|
|
|
|
|
@ -303,9 +290,11 @@ class Trainer:
|
|
|
|
|
# reset hooks
|
|
|
|
|
self._reset_states()
|
|
|
|
|
if hooks is not None:
|
|
|
|
|
assert isinstance(
|
|
|
|
|
hooks, list
|
|
|
|
|
), f"expected argument hooks be to list, but got {type(hooks)}"
|
|
|
|
|
assert isinstance(hooks, list), f"expected argument hooks be to list, but got {type(hooks)}"
|
|
|
|
|
|
|
|
|
|
for hook in hooks:
|
|
|
|
|
assert isinstance(hook, BaseHook), \
|
|
|
|
|
f'expected the hook to be of type BaseHook, but got {type(hook)}'
|
|
|
|
|
else:
|
|
|
|
|
hooks = []
|
|
|
|
|
self.hooks = hooks
|
|
|
|
@ -316,9 +305,7 @@ class Trainer:
|
|
|
|
|
f"Using {hook.__class__.__name__} for training, priority = {hook.priority}",
|
|
|
|
|
ranks=[0],
|
|
|
|
|
)
|
|
|
|
|
self._logger.info(
|
|
|
|
|
"Lower value means higher priority for calling hook function",
|
|
|
|
|
ranks=[0])
|
|
|
|
|
self._logger.info("Lower value means higher priority for calling hook function", ranks=[0])
|
|
|
|
|
self._call_hooks("after_hook_is_attached")
|
|
|
|
|
|
|
|
|
|
self._engine.train()
|
|
|
|
@ -360,11 +347,11 @@ class Trainer:
|
|
|
|
|
self._call_timer("reset", "Train-epoch")
|
|
|
|
|
|
|
|
|
|
def evaluate(
|
|
|
|
|
self,
|
|
|
|
|
test_dataloader: DataLoader,
|
|
|
|
|
hooks: List[BaseHook] = None,
|
|
|
|
|
display_progress: bool = False,
|
|
|
|
|
return_output_label: bool = True,
|
|
|
|
|
self,
|
|
|
|
|
test_dataloader: DataLoader,
|
|
|
|
|
hooks: List[BaseHook] = None,
|
|
|
|
|
display_progress: bool = False,
|
|
|
|
|
return_output_label: bool = True,
|
|
|
|
|
):
|
|
|
|
|
"""Evaluates the model with testing data.
|
|
|
|
|
|
|
|
|
@ -381,9 +368,7 @@ class Trainer:
|
|
|
|
|
# reset hooks
|
|
|
|
|
self._reset_states()
|
|
|
|
|
if hooks is not None:
|
|
|
|
|
assert isinstance(
|
|
|
|
|
hooks, list
|
|
|
|
|
), f"expected argument hooks be to list, but got {type(hooks)}"
|
|
|
|
|
assert isinstance(hooks, list), f"expected argument hooks be to list, but got {type(hooks)}"
|
|
|
|
|
else:
|
|
|
|
|
hooks = []
|
|
|
|
|
self.hooks = hooks
|
|
|
|
@ -394,9 +379,7 @@ class Trainer:
|
|
|
|
|
f"Using {hook.__class__.__name__} for training, priority = {hook.priority}",
|
|
|
|
|
ranks=[0],
|
|
|
|
|
)
|
|
|
|
|
self._logger.info(
|
|
|
|
|
"Lower value means higher priority for calling hook function",
|
|
|
|
|
ranks=[0])
|
|
|
|
|
self._logger.info("Lower value means higher priority for calling hook function", ranks=[0])
|
|
|
|
|
self._call_hooks("after_hook_is_attached")
|
|
|
|
|
|
|
|
|
|
# eval
|
|
|
|
@ -406,7 +389,7 @@ class Trainer:
|
|
|
|
|
return_output_label=return_output_label,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def predict(self, data: Union[Tensor, List[Tensor]]):
|
|
|
|
|
def predict(self, data: Union[Any, List[Any]]):
|
|
|
|
|
"""Uses trained model to make a prediction for a tensor or a tensor list.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
@ -416,17 +399,11 @@ class Trainer:
|
|
|
|
|
:class:`torch.tensor`: The output of model as the prediction
|
|
|
|
|
"""
|
|
|
|
|
# predict without labels
|
|
|
|
|
if isinstance(data, (list, tuple)):
|
|
|
|
|
assert isinstance(data[0], Tensor)
|
|
|
|
|
else:
|
|
|
|
|
assert isinstance(data, Tensor)
|
|
|
|
|
self._engine.eval()
|
|
|
|
|
|
|
|
|
|
# prepare a list of (data, label) to make it iterable
|
|
|
|
|
# for compatibility with schedule
|
|
|
|
|
simple_dataloader = [(data, None)]
|
|
|
|
|
data_iter = iter(simple_dataloader)
|
|
|
|
|
output, _, _ = self.engine.execute_schedule(data_iter,
|
|
|
|
|
forward_only=True,
|
|
|
|
|
return_loss=False)
|
|
|
|
|
output, _, _ = self.engine.execute_schedule(data_iter, forward_only=True, return_loss=False)
|
|
|
|
|
return output
|
|
|
|
|