diff --git a/colossalai/nn/lr_scheduler/multistep.py b/colossalai/nn/lr_scheduler/multistep.py index b13bc056b..29531a9e3 100644 --- a/colossalai/nn/lr_scheduler/multistep.py +++ b/colossalai/nn/lr_scheduler/multistep.py @@ -22,7 +22,13 @@ class MultiStepLR(_MultiStepLR): the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. """ - def __init__(self, optimizer, total_steps: int, milestones: List[int] = None, gamma: float = 0.1, last_epoch: int = -1, **kwargs): + def __init__(self, + optimizer, + total_steps: int, + milestones: List[int] = None, + gamma: float = 0.1, + last_epoch: int = -1, + **kwargs): super().__init__(optimizer, milestones, gamma=gamma, last_epoch=last_epoch) @@ -41,12 +47,16 @@ class MultiStepWarmupLR(WarmupScheduler): the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. """ - def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, milestones: List[int] = None, - gamma: float = 0.1, last_epoch: int = -1, **kwargs): + def __init__(self, + optimizer, + total_steps: int, + warmup_steps: int = 0, + milestones: List[int] = None, + gamma: float = 0.1, + last_epoch: int = -1, + **kwargs): if len(milestones) == 0: raise ValueError('milestones cannot be empty') - milestones = [ - v - warmup_steps for v in milestones if v >= warmup_steps] - base_scheduler = _MultiStepLR(optimizer, milestones=milestones, - gamma=gamma) + milestones = [v - warmup_steps for v in milestones if v >= warmup_steps] + base_scheduler = _MultiStepLR(optimizer, milestones=milestones, gamma=gamma) super().__init__(optimizer, warmup_steps, base_scheduler, last_epoch=last_epoch)