mirror of https://github.com/InternLM/InternLM
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch


class Beta2Scheduler:
    """
    Anneals Adam's beta2 upward over training: at iteration t the schedule
    proposes 1 - 1 / t**c and applies it once it exceeds init_beta2.
    """
    def __init__(self, optimizer: torch.optim.Adam, init_beta2, c=0.8, cur_iter=-1):
        self.cur_iter = 0 if cur_iter == -1 else cur_iter
        self.init_beta2 = init_beta2
        self.c = c
        self.optimizer = optimizer
        assert isinstance(
            optimizer, (torch.optim.Adam, torch.optim.AdamW)
        ), "should use an Adam optimizer, which has beta2"
    def step(self, cur_iter=None):
        """Advance one iteration (or jump to cur_iter) and write the
        scheduled beta2 into every param group of the optimizer."""
        if cur_iter is None:
            self.cur_iter += 1
        else:
            self.cur_iter = cur_iter

        new_beta2 = self.get_beta2()
        for pg in self.optimizer.param_groups:
            beta1, _ = pg["betas"]
            pg["betas"] = (beta1, new_beta2)
    def get_beta2(self):
        """Return the scheduled beta2 for the current iteration."""
        # A non-positive exponent disables the schedule entirely.
        if self.c <= 0:
            return self.init_beta2
        # Guard: before the first step(), self.cur_iter may still be 0,
        # and 1 / 0**c would raise ZeroDivisionError.
        if self.cur_iter <= 0:
            return self.init_beta2
        scale = 1 - (1 / self.cur_iter**self.c)
        return max(self.init_beta2, scale)
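
# --- Usage sketch (illustrative, not part of the upstream file) ---
# A minimal example, assuming a toy model and AdamW settings chosen purely
# for demonstration: the scheduler is stepped once per training iteration
# and rewrites beta2 in every param group.
if __name__ == "__main__":
    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, betas=(0.9, 0.95))
    scheduler = Beta2Scheduler(optimizer, init_beta2=0.95, c=0.8)
    for _ in range(100):
        # ... forward / backward / optimizer.step() would run here ...
        scheduler.step()
    # With c=0.8, 1 - 1/t**0.8 first exceeds init_beta2=0.95 near t ~= 43,
    # after which beta2 rises toward 1; at t=100 it is roughly 0.975.
    print(optimizer.param_groups[0]["betas"])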