You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/colossalai/legacy/trainer/hooks/_lr_scheduler_hook.py

50 lines
2.0 KiB

from torch import Tensor
from colossalai.legacy.registry import HOOKS
from ._metric_hook import LearningRateMetric, MetricHook
@HOOKS.register_module
class LRSchedulerHook(MetricHook):
    r"""Trainer hook that drives an LR scheduler and records the learning rate.

    Depending on ``by_epoch``, the scheduler is stepped either once after every
    training epoch or once after every training iteration. After each step, the
    learning rate of the first parameter group (``get_last_lr()[0]``) is pushed
    into the ``trainer.states["metrics"]["train"]["LR"]`` metric.

    Args:
        lr_scheduler (:class:`colossalai.nn.lr_scheduler`): The LR scheduler to drive, one of
            ``colossalai.nn.lr_scheduler``; more details about ``lr_scheduler`` can be found in
            `lr_scheduler <https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/nn/lr_scheduler>`_.
            It must provide ``step()`` and ``get_last_lr()``.
        by_epoch (bool): If `True`, the LR is scheduled every epoch. Otherwise, the LR is scheduled every batch.
        store_lr_in_state (bool, optional): If `True`, store the learning rate in each state, defaults to `True`.
        priority (int, optional): Priority in the printing; hooks with a smaller priority are printed first,
            defaults to 1. If different hooks share the same priority, the printing order
            depends on the hooks' order in the hook list.
    """

    def __init__(
        self,
        lr_scheduler,
        by_epoch: bool,
        store_lr_in_state: bool = True,
        priority: int = 1,
    ):
        super().__init__(priority=priority)
        self.lr_scheduler = lr_scheduler
        self.by_epoch = by_epoch
        # NOTE(review): this flag is only stored; nothing in this class reads it.
        # Confirm whether MetricHook or another hook consumes it.
        self.store_lr_in_state = store_lr_in_state

    def after_hook_is_attached(self, trainer):
        """Install a ``LearningRateMetric`` as the trainer's ``"LR"`` train metric.

        The metric is seeded with the scheduler's current first-group learning
        rate and is told whether LR updates happen per epoch or per batch.
        """
        # Helper inherited from MetricHook; presumably ensures the nested
        # trainer.states["metrics"]["train"] mapping exists — TODO confirm.
        self._check_metric_states_initialization(trainer)
        initial_lr = self.lr_scheduler.get_last_lr()[0]
        lr_metric = LearningRateMetric(epoch_only=self.by_epoch, initial_lr=initial_lr)
        trainer.states["metrics"]["train"]["LR"] = lr_metric

    def after_train_epoch(self, trainer):
        """Step the scheduler at the end of an epoch, but only in per-epoch mode."""
        if not self.by_epoch:
            return
        self._step_and_record_lr(trainer)

    def after_train_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor):
        """Step the scheduler after a training iteration, but only in per-batch mode.

        ``output``, ``label`` and ``loss`` belong to the hook signature and are unused here.
        """
        if self.by_epoch:
            return
        self._step_and_record_lr(trainer)

    def _step_and_record_lr(self, trainer):
        """Advance the scheduler once and feed its new first-group LR to the "LR" metric."""
        self.lr_scheduler.step()
        current_lr = self.lr_scheduler.get_last_lr()[0]
        trainer.states["metrics"]["train"]["LR"].update(current_lr)