diff --git a/colossalai/nn/lr_scheduler/cosine.py b/colossalai/nn/lr_scheduler/cosine.py
index 179585cd7..aab523bef 100644
--- a/colossalai/nn/lr_scheduler/cosine.py
+++ b/colossalai/nn/lr_scheduler/cosine.py
@@ -62,8 +62,10 @@ class CosineAnnealingWarmupLR(WarmupScheduler):
     """
 
     def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, eta_min: float = 0., last_epoch: int = -1):
-        base_scheduler = _CosineAnnealingLR(
-            optimizer, total_steps - warmup_steps, eta_min=eta_min, last_epoch=last_epoch)
+        base_scheduler = _CosineAnnealingLR(optimizer,
+                                            total_steps - warmup_steps,
+                                            eta_min=eta_min,
+                                            last_epoch=last_epoch)
         super().__init__(optimizer, warmup_steps, base_scheduler)
 
 
@@ -81,12 +83,10 @@ class FlatAnnealingLR(DelayerScheduler):
 
     def __init__(self, optimizer, total_steps: int, pct_start: float = 0.72, last_epoch: int = -1, **kwargs):
         if not (0.0 <= pct_start <= 1.0):
-            raise ValueError(
-                f'pct_start must >= 0.0 and <= 1.0, got {pct_start}')
+            raise ValueError(f'pct_start must >= 0.0 and <= 1.0, got {pct_start}')
         flat_steps = int(total_steps * pct_start)
         anneal_steps = total_steps - flat_steps
-        base_scheduler = _CosineAnnealingLR(
-            optimizer, anneal_steps)
+        base_scheduler = _CosineAnnealingLR(optimizer, anneal_steps)
         super().__init__(optimizer, flat_steps, base_scheduler, last_epoch=last_epoch)
 
 
@@ -105,14 +105,17 @@ class FlatAnnealingWarmupLR(WarmupDelayerScheduler):
         the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr.
     """
 
-    def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, pct_start: float = 0.72, eta_min: int = 0,
-                 last_epoch: int = -1, **kwargs):
+    def __init__(self,
+                 optimizer,
+                 total_steps: int,
+                 warmup_steps: int = 0,
+                 pct_start: float = 0.72,
+                 eta_min: int = 0,
+                 last_epoch: int = -1,
+                 **kwargs):
         if not (0.0 <= pct_start <= 1.0):
-            raise ValueError(
-                f'pct_start must >= 0.0 and <= 1.0, got {pct_start}')
+            raise ValueError(f'pct_start must >= 0.0 and <= 1.0, got {pct_start}')
         flat_steps = int((total_steps - warmup_steps) * pct_start)
         anneal_steps = total_steps - warmup_steps - flat_steps
-        base_scheduler = _CosineAnnealingLR(
-            optimizer, anneal_steps, eta_min=eta_min)
-        super().__init__(optimizer, warmup_steps, flat_steps,
-                         base_scheduler, last_epoch=last_epoch)
+        base_scheduler = _CosineAnnealingLR(optimizer, anneal_steps, eta_min=eta_min)
+        super().__init__(optimizer, warmup_steps, flat_steps, base_scheduler, last_epoch=last_epoch)
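
Below is a minimal usage sketch of the schedulers touched by this diff. The constructor signatures are taken from the hunks above; the top-level import path `colossalai.nn.lr_scheduler` re-exporting these classes, the plain PyTorch optimizer, and the per-step `scheduler.step()` call are assumptions for illustration, not part of the change.

```python
# Minimal sketch, not part of this diff. Assumes the classes are importable
# from `colossalai.nn.lr_scheduler`; signatures follow the hunks above.
import torch

from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR, FlatAnnealingWarmupLR

model = torch.nn.Linear(16, 16)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
total_steps = 1000

# Warm up for 100 steps, then cosine-anneal the lr over the remaining 900 steps.
scheduler = CosineAnnealingWarmupLR(optimizer, total_steps=total_steps, warmup_steps=100, eta_min=1e-5)

# Alternative: warm up, hold the lr flat for 72% of the post-warmup steps, then anneal.
# scheduler = FlatAnnealingWarmupLR(optimizer, total_steps=total_steps, warmup_steps=100, pct_start=0.72)

for _ in range(total_steps):
    optimizer.step()   # one training step (forward/backward omitted for brevity)
    scheduler.step()   # advance the schedule once per step, since total_steps counts steps
```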