[NFC] polish colossalai/nn/lr_scheduler/multistep.py code style (#1572)

pull/1550/head
Sze-qq 2022-09-08 16:43:08 +08:00 committed by Frank Lee
parent e4bf7ae667
commit 2144cbae8c
1 changed files with 17 additions and 7 deletions

View File

@ -22,7 +22,13 @@ class MultiStepLR(_MultiStepLR):
the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr.
"""
def __init__(self, optimizer, total_steps: int, milestones: List[int] = None, gamma: float = 0.1, last_epoch: int = -1, **kwargs):
def __init__(self,
optimizer,
total_steps: int,
milestones: List[int] = None,
gamma: float = 0.1,
last_epoch: int = -1,
**kwargs):
super().__init__(optimizer, milestones, gamma=gamma, last_epoch=last_epoch)
@ -41,12 +47,16 @@ class MultiStepWarmupLR(WarmupScheduler):
the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr.
"""
def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, milestones: List[int] = None,
gamma: float = 0.1, last_epoch: int = -1, **kwargs):
def __init__(self,
optimizer,
total_steps: int,
warmup_steps: int = 0,
milestones: List[int] = None,
gamma: float = 0.1,
last_epoch: int = -1,
**kwargs):
if len(milestones) == 0:
raise ValueError('milestones cannot be empty')
milestones = [
v - warmup_steps for v in milestones if v >= warmup_steps]
base_scheduler = _MultiStepLR(optimizer, milestones=milestones,
gamma=gamma)
milestones = [v - warmup_steps for v in milestones if v >= warmup_steps]
base_scheduler = _MultiStepLR(optimizer, milestones=milestones, gamma=gamma)
super().__init__(optimizer, warmup_steps, base_scheduler, last_epoch=last_epoch)