2023-09-04 11:56:42 +00:00
|
|
|
from torch.optim.lr_scheduler import ExponentialLR as _ExponentialLR
|
2021-10-28 16:21:23 +00:00
|
|
|
from torch.optim.lr_scheduler import LambdaLR as _LambdaLR
|
|
|
|
from torch.optim.lr_scheduler import MultiplicativeLR as _MultiplicativeLR
|
|
|
|
from torch.optim.lr_scheduler import StepLR as _StepLR
|
|
|
|
|
|
|
|
|
|
|
|
class LambdaLR(_LambdaLR):
    """Sets the learning rate of each parameter group to the initial lr
    times a given function. When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer.
        total_steps (int): Number of total training steps. Accepted for signature
            compatibility but not used by this scheduler.
        lr_lambda (Union[``function``, ``list[function]``]): A function which computes a multiplicative
            factor given an integer parameter epoch, or a list of such functions,
            one for each group in optimizer.param_groups. The signature defaults it to
            None, but it is required in practice: passing None raises ``TypeError``.
        last_epoch (int, optional): The index of last epoch, defaults to -1.

    Raises:
        TypeError: If ``lr_lambda`` is None.
    """

    def __init__(self, optimizer, total_steps, lr_lambda=None, last_epoch: int = -1) -> None:
        # NOTE(review): total_steps is never used here; presumably it is kept so all
        # schedulers share one constructor shape — confirm with callers.
        if lr_lambda is None:
            # torch's LambdaLR calls each lambda while computing the lr (already during
            # construction), so None would otherwise surface as an opaque
            # "'NoneType' object is not callable" error. Fail early with a clear message.
            raise TypeError("LambdaLR requires `lr_lambda` (a callable or a list of callables), got None")
        super().__init__(optimizer, lr_lambda, last_epoch=last_epoch)
|
2021-10-28 16:21:23 +00:00
|
|
|
|
|
|
|
|
|
|
|
class MultiplicativeLR(_MultiplicativeLR):
    """Multiply the learning rate of each parameter group by the factor given
    in the specified function. When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer.
        total_steps (int): Number of total training steps. Accepted for signature
            compatibility but not used by this scheduler.
        lr_lambda (Union[``function``, ``list[function]``]): A function which computes a multiplicative
            factor given an integer parameter epoch, or a list of such functions,
            one for each group in optimizer.param_groups. The signature defaults it to
            None, but it is required in practice: passing None raises ``TypeError``.
        last_epoch (int, optional): The index of last epoch, defaults to -1.

    Raises:
        TypeError: If ``lr_lambda`` is None.
    """

    def __init__(self, optimizer, total_steps, lr_lambda=None, last_epoch: int = -1) -> None:
        # NOTE(review): total_steps is never used here; presumably it is kept so all
        # schedulers share one constructor shape — confirm with callers.
        if lr_lambda is None:
            # torch's MultiplicativeLR only calls the lambda once last_epoch > 0, so a
            # None would otherwise construct fine and then fail on the first step().
            # Surface the misconfiguration at construction time instead.
            raise TypeError("MultiplicativeLR requires `lr_lambda` (a callable or a list of callables), got None")
        super().__init__(optimizer, lr_lambda, last_epoch=last_epoch)
|
2021-10-28 16:21:23 +00:00
|
|
|
|
|
|
|
|
|
|
|
class StepLR(_StepLR):
    """Decay each parameter group's learning rate by ``gamma`` once every
    ``step_size`` epochs.

    The decay may coincide with learning-rate changes applied from outside this
    scheduler. When ``last_epoch=-1``, the initial lr is used as lr.

    Args:
        optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer.
        total_steps (int): Number of total training steps (not used by this scheduler).
        step_size (int, optional): Period of learning rate decay, defaults to 1.
        gamma (float, optional): Multiplicative factor of learning rate decay, defaults to 0.1.
        last_epoch (int, optional): The index of last epoch, defaults to -1.
    """

    def __init__(self, optimizer, total_steps, step_size: int = 1, gamma: float = 0.1, last_epoch: int = -1) -> None:
        # Everything except ``total_steps`` is handed straight to torch's StepLR.
        super().__init__(
            optimizer=optimizer,
            step_size=step_size,
            gamma=gamma,
            last_epoch=last_epoch,
        )
|
2021-10-28 16:21:23 +00:00
|
|
|
|
|
|
|
|
2021-11-18 11:45:06 +00:00
|
|
|
class ExponentialLR(_ExponentialLR):
    """Decay the learning rate of each parameter group by ``gamma`` every epoch.

    When ``last_epoch=-1``, the initial lr is used as lr.

    Args:
        optimizer (Union[:class:`torch.optim.Optimizer`, :class:`colossalai.nn.optimizer`]): Wrapped optimizer.
        total_steps (int): Number of total training steps (not used by this scheduler).
        gamma (float, optional): Multiplicative factor of learning rate decay, defaults to 1.0.
            With the default of 1.0 the learning rate is left unchanged.
        last_epoch (int, optional): The index of last epoch, defaults to -1.
    """

    def __init__(self, optimizer, total_steps, gamma: float = 1.0, last_epoch: int = -1) -> None:
        # Only the torch-level arguments are forwarded; ``total_steps`` stays local.
        super().__init__(
            optimizer=optimizer,
            gamma=gamma,
            last_epoch=last_epoch,
        )
|