#!/usr/bin/env python # -*- encoding: utf-8 -*- from abc import ABC, abstractmethod from typing import Callable import torch import torch.distributed as dist from colossalai.communication import all_reduce from colossalai.context import ParallelMode from colossalai.core import global_context as gpc from colossalai.registry import HOOKS from colossalai.utils import get_current_device, is_no_pp_or_last_stage from ._base_hook import BaseHook class Metric(ABC): """A basic class of metric collectors. It collects a specific metric during training or evaluation and would always be used with :class:`MetricHook` to help it update its states and show the metric. So please use corresponding hook class to make the metric collector works. Args: epoch_only (bool): Whether the metric only read for the full epoch. """ def __init__(self, epoch_only: bool): # is the metric only read for the full epoch self._epoch_only = epoch_only @property def epoch_only(self): """Returns :attr:`epoch_only`. """ return self._epoch_only @abstractmethod def reset(self) -> None: """Resets the metric to it's initial state. By default, this is called at the start of each epoch. """ pass @abstractmethod def update(self, *args, **kwargs) -> None: """Updates the metric's state using the passed batch output. By default, this is called once for each batch. """ pass @abstractmethod def get_last_step_value(self): """Returns the metric value in the last iteration. """ pass @abstractmethod def get_accumulated_value(self): """Computes the metric based on it's accumulated state. By default, this is called at the end of each epoch. :return: the actual quantity of interest :rtype: Any """ pass @staticmethod @abstractmethod def is_better(a, b) -> bool: """Compares a and b, and returns whether a is better than b :return: The result of comparison :rtype: bool """ pass class LossMetric(Metric): """A metric collector for loss. Args: epoch_only (bool): Whether the metric only read for the full epoch. """ def __init__(self, epoch_only): super().__init__(epoch_only=epoch_only) self.last_step_loss = torch.zeros(1, device=get_current_device()) self.accum_loss = torch.zeros(1, device=get_current_device()) self.count = 0 def reset(self) -> None: """Sets :attr:`last_step_loss` and :attr:`accum_loss` to zero. """ self.last_step_loss.zero_() self.accum_loss.zero_() self.count = 0 def update(self, loss) -> None: """Updates :attr:`last_step_loss` and :attr:`accum_loss` with current loss. It expects the output has loss. Args: loss (:class:`torch.tensor`): Current loss of the output. """ # expect output to be logits, label and loss loss_ = loss.detach() self.last_step_loss.copy_(loss_) self.accum_loss.add_(loss_) self.count += 1 def get_accumulated_value(self): """Returns accumulated loss. """ if gpc.is_initialized(ParallelMode.DATA): dist.all_reduce(self.accum_loss, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.DATA)) self.accum_loss.div_(gpc.get_world_size(ParallelMode.DATA)) self.accum_loss.div_(self.count) return self.accum_loss.item() def get_last_step_value(self): """Returns :attr:`last_step_loss`. """ return self.last_step_loss @staticmethod def is_better(a, b): return a < b class LearningRateMetric(Metric): """A metric collector for learning rate. Args: epoch_only (bool): Whether the metric only read for the full epoch. initial_lr (float, optional): Initial learning rate, defaults to 0.0. """ def __init__(self, epoch_only: bool, initial_lr: float = 0.): super().__init__(epoch_only=epoch_only) self.lr = initial_lr def reset(self) -> None: pass def update(self, lr) -> None: self.lr = lr def get_last_step_value(self): return self.lr def get_accumulated_value(self): return self.lr @staticmethod def is_better(a, b) -> bool: pass class AccuracyMetric(Metric): """A metric collector for accuracy. It only works for classification tasks. Args: epoch_only (bool): Whether the metric only read for the full epoch. accuracy_func (:class:`typing.Callable`): Accuracy function for the classification task. """ def __init__(self, epoch_only: bool, accuracy_func: Callable): super().__init__(epoch_only=epoch_only) self.acc = accuracy_func self.last_step_sum = torch.zeros(1, device=get_current_device()) self.last_step_correct = torch.zeros(1, device=get_current_device()) self.accumulated_sum = torch.zeros(1, device=get_current_device()) self.accumulated_correct = torch.zeros(1, device=get_current_device()) def reset(self) -> None: self.last_step_sum.zero_() self.last_step_correct.zero_() self.accumulated_sum.zero_() self.accumulated_correct.zero_() def update(self, logits, targets, batch_size) -> None: """Updates last step accuracy and accumulated accuracy with current logits and labels. It expects the output has logits and labels. Args: logits (:class:`torch.tensor`): The logits output of the model. targets (:class:`torch.tensor`): Real labels of the dataset. batch_size (int): Batch size of the task. """ if isinstance(logits, (list, tuple)): logits = logits[0] if isinstance(targets, (list, tuple)): targets = targets[0] # update correct = self.acc(logits, targets) self.last_step_sum.fill_(batch_size) self.last_step_correct.fill_(correct) self.accumulated_sum += self.last_step_sum self.accumulated_correct += self.last_step_correct def get_last_step_value(self): self.last_step_sum = all_reduce(self.last_step_sum, ParallelMode.DATA) self.last_step_correct = all_reduce(self.last_step_correct, ParallelMode.DATA) return (self.last_step_correct / self.last_step_sum).item() def get_accumulated_value(self): self.accumulated_sum = all_reduce(self.accumulated_sum, ParallelMode.DATA) self.accumulated_correct = all_reduce(self.accumulated_correct, ParallelMode.DATA) return (self.accumulated_correct / self.accumulated_sum).item() @staticmethod def is_better(a, b) -> bool: return a > b class MetricHook(BaseHook): """Specialized hook classes for :class:`Metric`. Some help metric collectors initialize, reset and update their states. Others are used to display and record the metric. Args: priority (int): Priority in the printing, hooks with small priority will be printed in front defaults to 1. If different hooks share same priority, the order of printing would depend on the hooks order in the hook list. """ def __init__( self, priority: int, ): super().__init__(priority) self._is_stage_to_compute = is_no_pp_or_last_stage() def _check_metric_states_initialization(self, trainer): if 'metrics' not in trainer.states: self.init_runner_states(trainer, 'metrics', dict(train={}, test={})) @HOOKS.register_module class LossHook(MetricHook): """Specialized hook class for :class:`Loss`. Args: priority (int, optional): Priority in the printing, hooks with small priority will be printed in front defaults to 0. If different hooks share same priority, the order of printing would depend on the hooks order in the hook list. """ def __init__(self, priority: int = 0): super().__init__(priority) def after_hook_is_attached(self, trainer): self._check_metric_states_initialization(trainer) if self._is_stage_to_compute: self.train_loss = LossMetric(epoch_only=False) self.test_loss = LossMetric(epoch_only=True) # register the metric calculator trainer.states['metrics']['train']['Loss'] = self.train_loss trainer.states['metrics']['test']['Loss'] = self.test_loss def before_train_epoch(self, trainer): if self._is_stage_to_compute: self.train_loss.reset() def after_train_iter(self, trainer, logits, label, loss): if self._is_stage_to_compute: self.train_loss.update(loss) def before_test_epoch(self, trainer): if self._is_stage_to_compute: self.test_loss.reset() def after_test_iter(self, trainer, logits, label, loss): if self._is_stage_to_compute: self.test_loss.update(loss) @HOOKS.register_module class AccuracyHook(MetricHook): """Specialized hook class for :class:`Accuracy`. Args: accuracy_func (:class:`typing.Callable`): Accuracy function for the classification task. priority (int, optional): Priority in the printing, hooks with small priority will be printed in front defaults to 0. If different hooks share same priority, the order of printing would depend on the hooks order in the hook list. """ def __init__(self, accuracy_func: Callable, priority: int = 0): super().__init__(priority) self.accuracy_func = accuracy_func def after_hook_is_attached(self, trainer): self._check_metric_states_initialization(trainer) if self._is_stage_to_compute: self.metric = AccuracyMetric(epoch_only=True, accuracy_func=self.accuracy_func) # register the metric trainer.states['metrics']['test']['Accuracy'] = self.metric def before_test(self, trainer): if self._is_stage_to_compute: self.metric.reset() def after_test_iter(self, trainer, logits, targets, *args): if self._is_stage_to_compute: batch_size = trainer.schedule.batch_size self.metric.update(logits, targets, batch_size) class ThroughputMetric(Metric): """Metric for :class:`Throughput`. Args: epoch_only (bool): Whether the metric only read for the full epoch. """ def __init__(self, epoch_only: bool, ignored_steps: int = 0): super().__init__(epoch_only=epoch_only) self.ignored_steps = ignored_steps self.cur_steps = 0 self.accumulated_num_samples = torch.zeros(1, device=get_current_device()) self.accumulated_used_time = torch.zeros(1, device=get_current_device()) self.last_step_num_samples = torch.zeros(1, device=get_current_device()) self.last_step_used_time = torch.zeros(1, device=get_current_device()) def reset(self) -> None: # self.cur_steps = 0 self.accumulated_num_samples.zero_() self.accumulated_used_time.zero_() self.last_step_num_samples.zero_() self.last_step_used_time.zero_() def update(self, num_samples, time) -> None: self.cur_steps += 1 self.last_step_num_samples.fill_(num_samples) self.last_step_used_time.fill_(time) if self.cur_steps >= self.ignored_steps: self.accumulated_num_samples += self.last_step_num_samples self.accumulated_used_time += self.last_step_used_time def get_last_step_value(self): self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / \ gpc.get_world_size(ParallelMode.DATA) self.last_step_num_samples = all_reduce(self.last_step_num_samples, ParallelMode.DATA) return (self.last_step_num_samples / (self.last_step_used_time + 1e-12)).item() def get_accumulated_value(self): self.accumulated_used_time = all_reduce(self.accumulated_used_time, ParallelMode.DATA) / \ gpc.get_world_size(ParallelMode.DATA) self.accumulated_num_samples = all_reduce(self.accumulated_num_samples, ParallelMode.DATA) return (self.accumulated_num_samples / (self.accumulated_used_time + 1e-12)).item() @staticmethod def is_better(a, b) -> bool: pass @HOOKS.register_module class ThroughputHook(MetricHook): """Specialized hook class for :class:`Throughput`. Hook to measure execution throughput (samples/sec). Args: ignored_steps (int, optional): the number of initial training steps to ignore. priority (int, optional): Priority in the printing, hooks with small priority will be printed in front defaults to 10. If different hooks share same priority, the order of printing would depend on the hooks order in the hook list. """ def __init__(self, ignored_steps: int = 0, priority: int = 10): super().__init__(priority) self.ignored_steps = ignored_steps def after_hook_is_attached(self, trainer): self._check_metric_states_initialization(trainer) if self._is_stage_to_compute: self.metric = ThroughputMetric(epoch_only=True, ignored_steps=self.ignored_steps) # register the metric trainer.states['metrics']['train']['Throughput'] = self.metric trainer.states['metrics']['test']['Throughput'] = self.metric def before_train_epoch(self, trainer): if self._is_stage_to_compute: self.metric.reset() def after_train_iter(self, trainer, *args): if self._is_stage_to_compute: self.metric.update(trainer.schedule.batch_size, trainer._timer.get_timer('Train-step').get_elapsed_time()) def before_test(self, trainer): if self._is_stage_to_compute: self.metric.reset() def after_test_iter(self, trainer, *args): if self._is_stage_to_compute: self.metric.update(trainer.schedule.batch_size, trainer._timer.get_timer('Test-step').get_elapsed_time())