|
|
|
@ -1,13 +1,9 @@
|
|
|
|
|
from typing import Union, List |
|
|
|
|
from colossalai.context.parallel_mode import ParallelMode |
|
|
|
|
from typing import Union, List, Any |
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
|
from torch import Tensor |
|
|
|
|
from torch.utils.data import DataLoader |
|
|
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
|
from colossalai.core import global_context as gpc |
|
|
|
|
|
|
|
|
|
from colossalai.engine import Engine |
|
|
|
|
from colossalai.logging import DistributedLogger |
|
|
|
|
from colossalai.utils import MultiTimer |
|
|
|
@ -53,11 +49,12 @@ class Trainer:
|
|
|
|
|
`Training with engine and trainer <https://www.colossalai.org/docs/basics/engine_trainer>`_ |
|
|
|
|
and `ColossalAI-Examples <https://github.com/hpcaitech/ColossalAI-Examples/tree/main>`_. |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
def __init__( |
|
|
|
|
self, |
|
|
|
|
engine: Engine, |
|
|
|
|
timer: MultiTimer = None, |
|
|
|
|
logger: DistributedLogger = None, |
|
|
|
|
self, |
|
|
|
|
engine: Engine, |
|
|
|
|
timer: MultiTimer = None, |
|
|
|
|
logger: DistributedLogger = None, |
|
|
|
|
): |
|
|
|
|
# training-ralated params |
|
|
|
|
self._engine = engine |
|
|
|
@ -154,15 +151,14 @@ class Trainer:
|
|
|
|
|
@staticmethod |
|
|
|
|
def _should_display_progress(display_progress: bool): |
|
|
|
|
"""Only display progress on DP rank 0, TP rank 0 and PP last rank""" |
|
|
|
|
return (display_progress and is_dp_rank_0() and is_tp_rank_0() |
|
|
|
|
and is_no_pp_or_last_stage()) |
|
|
|
|
return (display_progress and is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage()) |
|
|
|
|
|
|
|
|
|
def _train_epoch( |
|
|
|
|
self, |
|
|
|
|
train_dataloader: DataLoader, |
|
|
|
|
epoch: int = None, |
|
|
|
|
display_progress: bool = False, |
|
|
|
|
return_output_label: bool = True, |
|
|
|
|
self, |
|
|
|
|
train_dataloader: DataLoader, |
|
|
|
|
epoch: int = None, |
|
|
|
|
display_progress: bool = False, |
|
|
|
|
return_output_label: bool = True, |
|
|
|
|
): |
|
|
|
|
# set training state |
|
|
|
|
self._engine.train() |
|
|
|
@ -189,9 +185,7 @@ class Trainer:
|
|
|
|
|
return_output_label=return_output_label, |
|
|
|
|
) |
|
|
|
|
self.engine.step() |
|
|
|
|
self._call_timer(action="stop", |
|
|
|
|
item="Train-step", |
|
|
|
|
keep_in_history=True) |
|
|
|
|
self._call_timer(action="stop", item="Train-step", keep_in_history=True) |
|
|
|
|
self._call_hooks("after_train_iter", output=(logits, label, loss)) |
|
|
|
|
|
|
|
|
|
self._cur_step += 1 |
|
|
|
@ -204,18 +198,16 @@ class Trainer:
|
|
|
|
|
if self._exceed_max_step(): |
|
|
|
|
break |
|
|
|
|
|
|
|
|
|
self._call_timer(action="stop", |
|
|
|
|
item="Train-epoch", |
|
|
|
|
keep_in_history=True) |
|
|
|
|
self._call_timer(action="stop", item="Train-epoch", keep_in_history=True) |
|
|
|
|
self._call_hooks("after_train_epoch") |
|
|
|
|
self._call_timer(action="reset", item="Train-epoch") |
|
|
|
|
|
|
|
|
|
def _eval( |
|
|
|
|
self, |
|
|
|
|
test_dataloader: DataLoader, |
|
|
|
|
epoch: int = None, |
|
|
|
|
display_progress: bool = False, |
|
|
|
|
return_output_label: bool = True, |
|
|
|
|
self, |
|
|
|
|
test_dataloader: DataLoader, |
|
|
|
|
epoch: int = None, |
|
|
|
|
display_progress: bool = False, |
|
|
|
|
return_output_label: bool = True, |
|
|
|
|
): |
|
|
|
|
# switch engine status |
|
|
|
|
self._engine.eval() |
|
|
|
@ -244,19 +236,14 @@ class Trainer:
|
|
|
|
|
return_loss=True, |
|
|
|
|
return_output_label=return_output_label, |
|
|
|
|
) |
|
|
|
|
self._call_timer(action="stop", |
|
|
|
|
item="Test-step", |
|
|
|
|
keep_in_history=True) |
|
|
|
|
self._call_hooks("after_test_iter", |
|
|
|
|
output=(logits, label, loss)) |
|
|
|
|
self._call_timer(action="stop", item="Test-step", keep_in_history=True) |
|
|
|
|
self._call_hooks("after_test_iter", output=(logits, label, loss)) |
|
|
|
|
|
|
|
|
|
if display_progress: |
|
|
|
|
if "step_metrics" in self.states: |
|
|
|
|
progress.set_postfix(**self.states["step_metrics"]) |
|
|
|
|
|
|
|
|
|
self._call_timer(action="stop", |
|
|
|
|
item="Test-epoch", |
|
|
|
|
keep_in_history=True) |
|
|
|
|
self._call_timer(action="stop", item="Test-epoch", keep_in_history=True) |
|
|
|
|
self._call_hooks("after_test_epoch") |
|
|
|
|
self._call_hooks("after_test") |
|
|
|
|
self._call_timer(action="reset", item="Test-step") |
|
|
|
@ -266,15 +253,15 @@ class Trainer:
|
|
|
|
|
return self._max_steps is not None and self._cur_step >= self._max_steps |
|
|
|
|
|
|
|
|
|
def fit( |
|
|
|
|
self, |
|
|
|
|
train_dataloader: DataLoader, |
|
|
|
|
epochs: int, |
|
|
|
|
max_steps: int = None, |
|
|
|
|
test_dataloader: DataLoader = None, |
|
|
|
|
test_interval: int = 1, |
|
|
|
|
hooks: List[BaseHook] = None, |
|
|
|
|
display_progress: bool = False, |
|
|
|
|
return_output_label: bool = True, |
|
|
|
|
self, |
|
|
|
|
train_dataloader: DataLoader, |
|
|
|
|
epochs: int, |
|
|
|
|
max_steps: int = None, |
|
|
|
|
test_dataloader: DataLoader = None, |
|
|
|
|
test_interval: int = 1, |
|
|
|
|
hooks: List[BaseHook] = None, |
|
|
|
|
display_progress: bool = False, |
|
|
|
|
return_output_label: bool = True, |
|
|
|
|
): |
|
|
|
|
r"""Trains the model to fit training data. |
|
|
|
|
|
|
|
|
@ -303,9 +290,11 @@ class Trainer:
|
|
|
|
|
# reset hooks |
|
|
|
|
self._reset_states() |
|
|
|
|
if hooks is not None: |
|
|
|
|
assert isinstance( |
|
|
|
|
hooks, list |
|
|
|
|
), f"expected argument hooks be to list, but got {type(hooks)}" |
|
|
|
|
assert isinstance(hooks, list), f"expected argument hooks be to list, but got {type(hooks)}" |
|
|
|
|
|
|
|
|
|
for hook in hooks: |
|
|
|
|
assert isinstance(hook, BaseHook), \ |
|
|
|
|
f'expected the hook to be of type BaseHook, but got {type(hook)}' |
|
|
|
|
else: |
|
|
|
|
hooks = [] |
|
|
|
|
self.hooks = hooks |
|
|
|
@ -316,9 +305,7 @@ class Trainer:
|
|
|
|
|
f"Using {hook.__class__.__name__} for training, priority = {hook.priority}", |
|
|
|
|
ranks=[0], |
|
|
|
|
) |
|
|
|
|
self._logger.info( |
|
|
|
|
"Lower value means higher priority for calling hook function", |
|
|
|
|
ranks=[0]) |
|
|
|
|
self._logger.info("Lower value means higher priority for calling hook function", ranks=[0]) |
|
|
|
|
self._call_hooks("after_hook_is_attached") |
|
|
|
|
|
|
|
|
|
self._engine.train() |
|
|
|
@ -360,11 +347,11 @@ class Trainer:
|
|
|
|
|
self._call_timer("reset", "Train-epoch") |
|
|
|
|
|
|
|
|
|
def evaluate( |
|
|
|
|
self, |
|
|
|
|
test_dataloader: DataLoader, |
|
|
|
|
hooks: List[BaseHook] = None, |
|
|
|
|
display_progress: bool = False, |
|
|
|
|
return_output_label: bool = True, |
|
|
|
|
self, |
|
|
|
|
test_dataloader: DataLoader, |
|
|
|
|
hooks: List[BaseHook] = None, |
|
|
|
|
display_progress: bool = False, |
|
|
|
|
return_output_label: bool = True, |
|
|
|
|
): |
|
|
|
|
"""Evaluates the model with testing data. |
|
|
|
|
|
|
|
|
@ -381,9 +368,7 @@ class Trainer:
|
|
|
|
|
# reset hooks |
|
|
|
|
self._reset_states() |
|
|
|
|
if hooks is not None: |
|
|
|
|
assert isinstance( |
|
|
|
|
hooks, list |
|
|
|
|
), f"expected argument hooks be to list, but got {type(hooks)}" |
|
|
|
|
assert isinstance(hooks, list), f"expected argument hooks be to list, but got {type(hooks)}" |
|
|
|
|
else: |
|
|
|
|
hooks = [] |
|
|
|
|
self.hooks = hooks |
|
|
|
@ -394,9 +379,7 @@ class Trainer:
|
|
|
|
|
f"Using {hook.__class__.__name__} for training, priority = {hook.priority}", |
|
|
|
|
ranks=[0], |
|
|
|
|
) |
|
|
|
|
self._logger.info( |
|
|
|
|
"Lower value means higher priority for calling hook function", |
|
|
|
|
ranks=[0]) |
|
|
|
|
self._logger.info("Lower value means higher priority for calling hook function", ranks=[0]) |
|
|
|
|
self._call_hooks("after_hook_is_attached") |
|
|
|
|
|
|
|
|
|
# eval |
|
|
|
@ -406,7 +389,7 @@ class Trainer:
|
|
|
|
|
return_output_label=return_output_label, |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
def predict(self, data: Union[Tensor, List[Tensor]]): |
|
|
|
|
def predict(self, data: Union[Any, List[Any]]): |
|
|
|
|
"""Uses trained model to make a prediction for a tensor or a tensor list. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
@ -416,17 +399,11 @@ class Trainer:
|
|
|
|
|
:class:`torch.tensor`: The output of model as the prediction |
|
|
|
|
""" |
|
|
|
|
# predict without labels |
|
|
|
|
if isinstance(data, (list, tuple)): |
|
|
|
|
assert isinstance(data[0], Tensor) |
|
|
|
|
else: |
|
|
|
|
assert isinstance(data, Tensor) |
|
|
|
|
self._engine.eval() |
|
|
|
|
|
|
|
|
|
# prepare a list of (data, label) to make it iterable |
|
|
|
|
# for compatibility with schedule |
|
|
|
|
simple_dataloader = [(data, None)] |
|
|
|
|
data_iter = iter(simple_dataloader) |
|
|
|
|
output, _, _ = self.engine.execute_schedule(data_iter, |
|
|
|
|
forward_only=True, |
|
|
|
|
return_loss=False) |
|
|
|
|
output, _, _ = self.engine.execute_schedule(data_iter, forward_only=True, return_loss=False) |
|
|
|
|
return output |
|
|
|
|