#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import os
import os.path as osp

import torch
from tensorboardX import SummaryWriter

from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import HOOKS
from colossalai.trainer._trainer import Trainer
from colossalai.utils import (get_global_multitimer, is_dp_rank_0, is_no_pp_or_last_stage, is_tp_rank_0,
                              report_memory_usage, set_global_multitimer_status)

from ._metric_hook import MetricHook


def _format_number(val):
    # format Python floats and floating-point tensors to 5 decimal places;
    # return any other value unchanged
    if isinstance(val, float):
        return f'{val:.5f}'
    elif torch.is_floating_point(val):
        return f'{val.item():.5f}'
    return val


class EpochIntervalHook(MetricHook):
    """Base class for hooks that trigger once every ``interval`` epochs."""

    def __init__(self, trainer: Trainer, interval: int = 1, priority: int = 1):
        super().__init__(trainer, priority)
        self._interval = interval

    def _is_epoch_to_log(self):
        return self.trainer.cur_epoch % self._interval == 0
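
# Note on the interval check above: with interval=2, a subclass hook fires on
# epochs 0, 2, 4, ..., since cur_epoch % interval == 0 on exactly those epochs.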


@HOOKS.register_module
class LogMetricByEpochHook(EpochIntervalHook):
    """Specialized hook to record metrics to the log by epoch.

    :param trainer: Trainer attached to the current hook
    :type trainer: Trainer
    :param interval: Recording interval, in epochs
    :type interval: int, optional
    :param priority: Priority in the printing; hooks with a smaller priority value are printed first
    :type priority: int, optional
    """

    def __init__(self, trainer: Trainer, interval: int = 1, priority: int = 1) -> None:
        super().__init__(trainer=trainer, interval=interval, priority=priority)
        # log on a single rank only: data-parallel rank 0, tensor-parallel rank 0,
        # and the last pipeline stage (or any stage when pipeline parallelism is off)
        self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage()

    def _get_str(self, mode):
        msg = []
        for metric_name, metric_calculator in self.trainer.states['metrics'][mode].items():
            msg.append(
                f'{metric_name} = {_format_number(metric_calculator.get_accumulated_value())}')
        msg = ', '.join(msg)
        return msg

    def after_train_epoch(self):
        if self._is_epoch_to_log():
            msg = self._get_str(mode='train')

            if self._is_rank_to_log:
                self.logger.info(
                    f'Training - Epoch {self.trainer.cur_epoch} - {self.__class__.__name__}: {msg}')

    def after_test_epoch(self):
        if self._is_epoch_to_log():
            msg = self._get_str(mode='test')

            if self._is_rank_to_log:
                self.logger.info(
                    f'Testing - Epoch {self.trainer.cur_epoch} - {self.__class__.__name__}: {msg}')
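
# Hedged usage sketch (not part of the original module): hooks are usually
# attached to the trainer when launching a run. The `trainer` object below is
# an illustrative assumption; in config-driven setups the registered name
# 'LogMetricByEpochHook' is resolved through the HOOKS registry instead.
#
#   hook = LogMetricByEpochHook(trainer=trainer, interval=1, priority=1)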


@HOOKS.register_module
class TensorboardHook(MetricHook):
    """Specialized hook to record metrics to Tensorboard.

    :param trainer: Trainer attached to the current hook
    :type trainer: Trainer
    :param log_dir: Directory to save the Tensorboard logs
    :type log_dir: str, optional
    :param priority: Priority in the printing; hooks with a smaller priority value are printed first
    :type priority: int, optional
    """

    def __init__(self, trainer: Trainer, log_dir: str, priority: int = 1) -> None:
        super().__init__(trainer=trainer, priority=priority)
        self._is_rank_to_log = is_no_pp_or_last_stage()

        if self._is_rank_to_log:
            # determine the rank used to name this writer's workspace
            if gpc.is_initialized(ParallelMode.GLOBAL):
                rank = gpc.get_global_rank()
            else:
                rank = 0

            log_dir = osp.join(log_dir, f'rank_{rank}')

            # create the workspace on this rank only
            if not osp.exists(log_dir):
                os.makedirs(log_dir)

            self.writer = SummaryWriter(
                log_dir=log_dir, filename_suffix=f'_rank_{rank}')

    def after_train_iter(self, *args):
        for metric_name, metric_calculator in self.trainer.states['metrics']['train'].items():
            if metric_calculator.epoch_only:
                # epoch-only metrics are written in after_train_epoch instead
                continue
            val = metric_calculator.get_last_step_value()
            if self._is_rank_to_log:
                self.writer.add_scalar(
                    f'{metric_name}/train', val, self.trainer.cur_step)

    def after_test_iter(self, *args):
        for metric_name, metric_calculator in self.trainer.states['metrics']['test'].items():
            if metric_calculator.epoch_only:
                # epoch-only metrics are written in after_test_epoch instead
                continue
            val = metric_calculator.get_last_step_value()
            if self._is_rank_to_log:
                self.writer.add_scalar(f'{metric_name}/test', val,
                                       self.trainer.cur_step)

    def after_test_epoch(self):
        for metric_name, metric_calculator in self.trainer.states['metrics']['test'].items():
            if metric_calculator.epoch_only:
                val = metric_calculator.get_accumulated_value()
                if self._is_rank_to_log:
                    self.writer.add_scalar(f'{metric_name}/test', val,
                                           self.trainer.cur_step)

    def after_train_epoch(self):
        for metric_name, metric_calculator in self.trainer.states['metrics']['train'].items():
            if metric_calculator.epoch_only:
                val = metric_calculator.get_accumulated_value()
                if self._is_rank_to_log:
                    self.writer.add_scalar(f'{metric_name}/train', val,
                                           self.trainer.cur_step)
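
# Hedged usage sketch: each logging rank writes under its own `rank_<n>`
# subdirectory, so pointing Tensorboard at the parent directory aggregates all
# ranks. The log path below is an illustrative assumption.
#
#   hook = TensorboardHook(trainer=trainer, log_dir='./tb_logs')
#   # then, from a shell:  tensorboard --logdir ./tb_logs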


@HOOKS.register_module
class LogTimingByEpochHook(EpochIntervalHook):
    """Specialized hook to write timing records to the log.

    :param trainer: Trainer attached to the current hook
    :type trainer: Trainer
    :param interval: Recording interval, in epochs
    :type interval: int, optional
    :param priority: Priority in the printing; hooks with a smaller priority value are printed first
    :type priority: int, optional
    :param log_eval: Whether to log timing during evaluation
    :type log_eval: bool, optional
    """

    def __init__(self,
                 trainer: Trainer,
                 interval: int = 1,
                 priority: int = 1,
                 log_eval: bool = True
                 ) -> None:
        super().__init__(trainer=trainer, interval=interval, priority=priority)
        # enable the global multitimer so that timing records are collected
        set_global_multitimer_status(True)
        self._global_timer = get_global_multitimer()
        self._log_eval = log_eval
        self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0()

    def _get_message(self):
        msg = []
        for timer_name, timer in self._global_timer:
            last_elapsed_time = timer.get_elapsed_time()
            if timer.has_history:
                history_mean = timer.get_history_mean()
                history_sum = timer.get_history_sum()
                msg.append(
                    f'{timer_name}: last elapsed time = {last_elapsed_time}, '
                    f'history sum = {history_sum}, history mean = {history_mean}')
            else:
                msg.append(
                    f'{timer_name}: last elapsed time = {last_elapsed_time}')

        msg = ', '.join(msg)
        return msg

    def after_train_epoch(self):
        """Writes the timing log after finishing a training epoch.
        """
        if self._is_epoch_to_log() and self._is_rank_to_log:
            msg = self._get_message()
            self.logger.info(
                f'Training - Epoch {self.trainer.cur_epoch} - {self.__class__.__name__}: {msg}')

    def after_test_epoch(self):
        """Writes the timing log after finishing a testing epoch.
        """
        if self._is_epoch_to_log() and self._is_rank_to_log and self._log_eval:
            msg = self._get_message()
            self.logger.info(
                f'Testing - Epoch {self.trainer.cur_epoch} - {self.__class__.__name__}: {msg}')
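
# Hedged sketch of where the timing records come from: the trainer (or user
# code) drives the global multitimer around the regions it wants measured, and
# this hook only formats what the timers accumulated. The timer name
# 'train-step' and the exact start/stop signatures below are assumptions.
#
#   timer = get_global_multitimer()
#   timer.start('train-step')
#   ...                                              # one training step
#   timer.stop('train-step', keep_in_history=True)   # assumed signature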


@HOOKS.register_module
class LogMemoryByEpochHook(EpochIntervalHook):
    """Specialized hook to write memory usage records to the log.

    :param trainer: Trainer attached to the current hook
    :type trainer: Trainer
    :param interval: Recording interval, in epochs
    :type interval: int, optional
    :param priority: Priority in the printing; hooks with a smaller priority value are printed first
    :type priority: int, optional
    :param log_eval: Whether to log memory usage during evaluation
    :type log_eval: bool, optional
    """

    def __init__(self,
                 trainer: Trainer,
                 interval: int = 1,
                 priority: int = 1,
                 log_eval: bool = True
                 ) -> None:
        super().__init__(trainer=trainer, interval=interval, priority=priority)
        self._log_eval = log_eval
        self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0()

    def before_train(self):
        """Reports memory usage before training starts.
        """
        if self._is_epoch_to_log() and self._is_rank_to_log:
            report_memory_usage('before-train')

    def after_train_epoch(self):
        """Writes the memory usage log after finishing a training epoch.
        """
        if self._is_epoch_to_log() and self._is_rank_to_log:
            report_memory_usage(
                f'After Train - Epoch {self.trainer.cur_epoch} - {self.__class__.__name__}')

    def after_test(self):
        """Reports memory usage after testing.
        """
        if self._is_epoch_to_log() and self._is_rank_to_log and self._log_eval:
            report_memory_usage(
                f'After Test - Epoch {self.trainer.cur_epoch} - {self.__class__.__name__}')
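
# Hedged usage sketch: the logging hooks in this module are typically combined,
# with `priority` controlling print order (smaller values print first). The
# `trainer` object below is an illustrative assumption.
#
#   hooks = [
#       LogMetricByEpochHook(trainer=trainer, priority=1),
#       LogTimingByEpochHook(trainer=trainer, priority=2),
#       LogMemoryByEpochHook(trainer=trainer, priority=10),
#   ]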