|
|
@ -68,7 +68,9 @@ class OneCycleLR(_OneCycleLR):
|
|
|
|
https://arxiv.org/abs/1708.07120
|
|
|
|
https://arxiv.org/abs/1708.07120
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, optimizer, total_steps: int,
|
|
|
|
def __init__(self,
|
|
|
|
|
|
|
|
optimizer,
|
|
|
|
|
|
|
|
total_steps: int,
|
|
|
|
pct_start=0.3,
|
|
|
|
pct_start=0.3,
|
|
|
|
anneal_strategy='cos',
|
|
|
|
anneal_strategy='cos',
|
|
|
|
cycle_momentum=True,
|
|
|
|
cycle_momentum=True,
|
|
|
@ -76,9 +78,12 @@ class OneCycleLR(_OneCycleLR):
|
|
|
|
max_momentum=0.95,
|
|
|
|
max_momentum=0.95,
|
|
|
|
div_factor=25.0,
|
|
|
|
div_factor=25.0,
|
|
|
|
final_div_factor=10000.0,
|
|
|
|
final_div_factor=10000.0,
|
|
|
|
last_epoch=-1, **kwargs):
|
|
|
|
last_epoch=-1,
|
|
|
|
|
|
|
|
**kwargs):
|
|
|
|
max_lrs = list(map(lambda group: group['lr'], optimizer.param_groups))
|
|
|
|
max_lrs = list(map(lambda group: group['lr'], optimizer.param_groups))
|
|
|
|
super().__init__(optimizer, max_lrs, total_steps=total_steps,
|
|
|
|
super().__init__(optimizer,
|
|
|
|
|
|
|
|
max_lrs,
|
|
|
|
|
|
|
|
total_steps=total_steps,
|
|
|
|
pct_start=pct_start,
|
|
|
|
pct_start=pct_start,
|
|
|
|
anneal_strategy=anneal_strategy,
|
|
|
|
anneal_strategy=anneal_strategy,
|
|
|
|
cycle_momentum=cycle_momentum,
|
|
|
|
cycle_momentum=cycle_momentum,
|
|
|
|