from typing import Union, List
from colossalai.context.parallel_mode import ParallelMode

import torch
from torch import Tensor
from torch.utils.data import DataLoader
from tqdm import tqdm

from colossalai.core import global_context as gpc

from colossalai.engine import Engine
from colossalai.engine.schedule import NonPipelineSchedule, BaseSchedule
from colossalai.logging import DistributedLogger
from colossalai.utils import MultiTimer
from colossalai.utils import is_dp_rank_0, is_tp_rank_0, is_no_pp_or_last_stage
from colossalai.trainer.hooks import BaseHook


class Trainer:
    """A class that simplifies the deployment of users' training and evaluation,
    so that they do not have to write their own scripts. It is similar to
    ``ignite.engine`` and ``keras.engine``, but is called ``Trainer``.

    :param engine: Engine responsible for the process function
    :type engine: :class:`Engine`
    :param schedule: Schedule responsible for forward and backward steps
    :type schedule: :class:`BaseSchedule`, optional
    :param timer: Timer used to monitor the whole training
    :type timer: :class:`MultiTimer`, optional
    :param logger: Logger used to record the whole training
    :type logger: :class:`colossalai.logging.DistributedLogger`, optional
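
    Example (a minimal usage sketch; assumes ``engine``, ``logger`` and
    ``train_dataloader`` were created beforehand, e.g. during
    ``colossalai.initialize``, which is not shown here):

    .. code-block:: python

        trainer = Trainer(engine=engine, logger=logger)
        trainer.fit(train_dataloader=train_dataloader,
                    epochs=10,
                    display_progress=True)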
    """
    def __init__(
        self,
        engine: Engine,
        schedule: BaseSchedule = None,
        timer: MultiTimer = None,
        logger: DistributedLogger = None,
    ):
        # training-related params
        self._engine = engine
        self._max_epochs = 0
        self._cur_epoch = 0
        self._max_steps = 0
        self._cur_step = 0
        self._steps_per_epoch = 0

        # misc params
        self._logger = logger
        self._verbose = logger is not None

        # hooks can store states in this dict, and they can be consumed by other hooks
        self.states = dict()

        # build hooks
        self.hooks = list()

        # multi-timer for time benchmarking
        self._timer = timer

        # set schedule which specifies the training iteration for the engine
        if schedule is None:
            schedule = NonPipelineSchedule()
        if (gpc.is_initialized(ParallelMode.PIPELINE)
                and gpc.get_world_size(ParallelMode.PIPELINE) > 1):
            assert not isinstance(
                schedule, NonPipelineSchedule
            ), "NonPipelineSchedule cannot be used for pipeline parallel training, please use PipelineSchedule instead."
        self._schedule = schedule
        self._schedule.pre_processing(engine)

    @property
    def cur_epoch(self):
        """Returns the index of the current epoch."""
        return self._cur_epoch

    @cur_epoch.setter
    def cur_epoch(self, epoch: int):
        """Sets how many epochs have been processed."""
        # allow setter for training resumption
        self._cur_epoch = epoch

    @property
    def cur_step(self):
        """Returns how many iteration steps have been processed."""
        return self._cur_step

    @property
    def max_epochs(self):
        return self._max_epochs

    @property
    def max_steps(self):
        return self._max_steps

    @property
    def steps_per_epoch(self):
        return self._steps_per_epoch

    @property
    def engine(self):
        return self._engine

    @property
    def schedule(self):
        return self._schedule

    def _set_current_step(self, epoch: int):
        """Sets the current step number based on the given epoch.

        :param epoch: Epoch number from which the current step is computed
        :type epoch: int
        """
        self._cur_step = epoch * self._steps_per_epoch

    def _call_timer(self, action: str, item: str, *args, **kwargs) -> None:
        """Calls a timer function with a given timer name.

        :param action: Function to be called on timer
        :type action: str
        :param item: Name of the timer
        :type item: str
        :param args: args used for the action function
        :param kwargs: kwargs used for the action function
        """
        if self._timer is not None:
            getattr(self._timer, action)(item, *args, **kwargs)

    def _reset_states(self) -> None:
        """Clears trainer states."""
        self.states = dict()

    def _call_hooks(self, func, output=None):
        """Calls specific hooks at the current time point.

        :param func: A string representing the time point
        :param output: Output of the model after running an iteration, or None at other time points
        :type func: str
        :type output: optional
        """
        # only the after-iteration hooks will receive output
        for hook in self.hooks:
            if output is None:
                getattr(hook, func)(self)
            else:
                getattr(hook, func)(self, *output)

    @staticmethod
    def _should_display_progress(display_progress: bool):
        """Only display progress on DP rank 0, TP rank 0 and the last PP stage."""
        return (display_progress and is_dp_rank_0() and is_tp_rank_0()
                and is_no_pp_or_last_stage())

    def _train_epoch(
        self,
        train_dataloader: DataLoader,
        epoch: int = None,
        display_progress: bool = False,
        return_output_label: bool = True,
    ):
        # set training state
        self._engine.train()
        data_iter = iter(train_dataloader)
        progress = range(self._steps_per_epoch)
        if display_progress:
            if epoch is None:
                progress = tqdm(progress, desc="[Train]")
            else:
                progress = tqdm(progress, desc=f"[Epoch {epoch} / Train]")

        self._call_hooks("before_train_epoch")
        self._call_timer(action="start", item="Train-epoch")
        for _ in progress:
            self._call_hooks("before_train_iter")
            self._call_timer(action="start", item="Train-step")

            # run 1 training step
            self.engine.zero_grad()
            logits, label, loss = self.schedule.forward_backward_step(
                self.engine,
                data_iter,
                forward_only=False,
                return_loss=True,
                return_output_label=return_output_label,
            )
            self.engine.step()
            self._call_timer(action="stop",
                             item="Train-step",
                             keep_in_history=True)
            self._call_hooks("after_train_iter", output=(logits, label, loss))

            self._cur_step += 1

            if display_progress:
                if "step_metrics" in self.states:
                    progress.set_postfix(**self.states["step_metrics"])

            # stop when max iter is reached
            if self._exceed_max_step():
                break

        self._call_timer(action="stop",
                         item="Train-epoch",
                         keep_in_history=True)
        self._call_hooks("after_train_epoch")
        self._call_timer(action="reset", item="Train-epoch")

    def _eval(
        self,
        test_dataloader: DataLoader,
        epoch: int = None,
        display_progress: bool = False,
        return_output_label: bool = True,
    ):
        # switch engine status
        self._engine.eval()

        data_iter = iter(test_dataloader)
        num_steps = len(test_dataloader)

        self._call_hooks("before_test")
        # prepare progress bar
        progress = range(num_steps)
        if display_progress:
            desc = "Evaluation"
            if epoch is not None:
                desc = f"[Epoch {epoch} / Test]"
            progress = tqdm(progress, desc=desc)

        self._call_hooks("before_test_epoch")
        self._call_timer(action="start", item="Test-epoch")
        with torch.no_grad():
            for _ in progress:
                self._call_hooks("before_test_iter")
                self._call_timer(action="start", item="Test-step")
                logits, label, loss = self.schedule.forward_backward_step(
                    self.engine,
                    data_iter,
                    forward_only=True,
                    return_loss=True,
                    return_output_label=return_output_label,
                )
                self._call_timer(action="stop",
                                 item="Test-step",
                                 keep_in_history=True)
                self._call_hooks("after_test_iter",
                                 output=(logits, label, loss))

                if display_progress:
                    if "step_metrics" in self.states:
                        progress.set_postfix(**self.states["step_metrics"])

        self._call_timer(action="stop",
                         item="Test-epoch",
                         keep_in_history=True)
        self._call_hooks("after_test_epoch")
        self._call_hooks("after_test")
        self._call_timer(action="reset", item="Test-step")
        self._call_timer(action="reset", item="Test-epoch")

    def _exceed_max_step(self):
        return self._max_steps is not None and self._cur_step >= self._max_steps

    def fit(
        self,
        train_dataloader: DataLoader,
        epochs: int,
        max_steps: int = None,
        test_dataloader: DataLoader = None,
        test_interval: int = 1,
        hooks: List[BaseHook] = None,
        display_progress: bool = False,
        return_output_label: bool = True,
    ):
        """Trains the model to fit the training data.

        :param train_dataloader: DataLoader for training
        :param epochs: Maximum number of epochs
        :param max_steps: Maximum number of running iterations
        :param test_dataloader: DataLoader for testing
        :param test_interval: Interval of testing, in epochs
        :param hooks: A list of hooks used in training
        :param display_progress: If True, the training progress will be printed
        :param return_output_label: If True, the output of the model and the label will be returned

        :type train_dataloader: DataLoader
        :type epochs: int
        :type max_steps: int, optional
        :type test_dataloader: DataLoader, optional
        :type test_interval: int, optional
        :type hooks: list, optional
        :type display_progress: bool, optional
        :type return_output_label: bool, optional
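
        Example (a minimal sketch; ``MyHook`` stands in for any user-defined
        :class:`BaseHook` subclass and is not part of this module):

        .. code-block:: python

            trainer.fit(train_dataloader=train_dataloader,
                        test_dataloader=test_dataloader,
                        epochs=num_epochs,
                        test_interval=1,
                        hooks=[MyHook()],
                        display_progress=True)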
        """

        # set epochs and steps, consider gradient accumulation
        self._steps_per_epoch = len(train_dataloader)
        self._max_steps = max_steps
        self._max_epochs = epochs

        # check if testing is required
        should_test = False
        if test_dataloader is not None:
            should_test = True

        display_progress = self._should_display_progress(display_progress)

        # reset hooks
        self._reset_states()
        if hooks is not None:
            assert isinstance(
                hooks, list
            ), f"expected argument hooks to be a list, but got {type(hooks)}"
        else:
            hooks = []
        self.hooks = hooks
        self.hooks.sort(key=lambda hook: hook.priority)
        if self._verbose:
            for hook in self.hooks:
                self._logger.info(
                    f"Using {hook.__class__.__name__} for training, priority = {hook.priority}",
                    ranks=[0],
                )
            self._logger.info(
                "Lower value means higher priority for calling hook function",
                ranks=[0])
        self._call_hooks("after_hook_is_attached")

        self._engine.train()
        self._call_hooks("before_train")

        # recover step value if resuming training
        last_epoch = self._cur_epoch
        if self.cur_epoch != 0:
            self._set_current_step(last_epoch)

        for epoch in range(last_epoch, epochs):
            # train for one epoch
            self._train_epoch(
                train_dataloader=train_dataloader,
                epoch=epoch,
                display_progress=display_progress,
                return_output_label=return_output_label,
            )

            # start eval
            if should_test and epoch % test_interval == 0:
                self._eval(
                    test_dataloader=test_dataloader,
                    display_progress=display_progress,
                    epoch=epoch,
                    return_output_label=return_output_label,
                )

            self._cur_epoch += 1

            # check for termination
            if self._exceed_max_step():
                if self._verbose:
                    self._logger.info(
                        f"Max number of steps {max_steps} has been reached, training is stopped automatically",
                        ranks=[0],
                    )
                break
        self._call_hooks("after_train")
        self._call_timer("reset", "Train-epoch")

    def evaluate(
        self,
        test_dataloader: DataLoader,
        hooks: List[BaseHook] = None,
        display_progress: bool = False,
        return_output_label: bool = True,
    ):
        """Evaluates the model with testing data.

        :param test_dataloader: DataLoader for testing
        :param hooks: A list of hooks used in evaluation
        :param display_progress: If True, the evaluation progress will be printed
        :param return_output_label: If True, the output of the model and the label will be returned

        :type test_dataloader: DataLoader
        :type hooks: list, optional
        :type display_progress: bool, optional
        :type return_output_label: bool, optional
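
        Example (a minimal sketch, assuming ``trainer`` and ``test_dataloader``
        already exist):

        .. code-block:: python

            trainer.evaluate(test_dataloader=test_dataloader,
                             display_progress=True)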
        """
        # set display
        display_progress = self._should_display_progress(display_progress)

        # reset hooks
        self._reset_states()
        if hooks is not None:
            assert isinstance(
                hooks, list
            ), f"expected argument hooks to be a list, but got {type(hooks)}"
        else:
            hooks = []
        self.hooks = hooks
        self.hooks.sort(key=lambda hook: hook.priority)
        if self._verbose:
            for hook in self.hooks:
                self._logger.info(
                    f"Using {hook.__class__.__name__} for evaluation, priority = {hook.priority}",
                    ranks=[0],
                )
            self._logger.info(
                "Lower value means higher priority for calling hook function",
                ranks=[0])
        self._call_hooks("after_hook_is_attached")

        # eval
        self._eval(
            test_dataloader=test_dataloader,
            display_progress=display_progress,
            return_output_label=return_output_label,
        )

    def predict(self, data: Union[Tensor, List[Tensor]]):
        """Uses the trained model to make a prediction for a tensor or a list of tensors.

        :param data: Data used as the input
        :type data: Union[Tensor, List[Tensor]]
        :return: The output of the model as the prediction
        :rtype: Tensor
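
        Example (a minimal sketch; the input shape is illustrative only):

        .. code-block:: python

            batch = torch.randn(4, 3, 32, 32)
            prediction = trainer.predict(batch)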
        """
        # predict without labels
        if isinstance(data, (list, tuple)):
            assert isinstance(data[0], Tensor)
        else:
            assert isinstance(data, Tensor)
        self._engine.eval()

        # prepare a list of (data, label) pairs to make it iterable
        # for compatibility with the schedule
        simple_dataloader = [(data, None)]
        data_iter = iter(simple_dataloader)
        output, _, _ = self.schedule.forward_backward_step(self.engine,
                                                           data_iter,
                                                           forward_only=True,
                                                           return_loss=False)
        return output