ColossalAI/tests/test_optimizer/test_lr_scheduler.py

import torch.nn as nn
from torch.optim import Adam

from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR


def test_lr_scheduler_save_load():
    model = nn.Linear(10, 10)
    optimizer = Adam(model.parameters(), lr=1e-3)
    scheduler = CosineAnnealingWarmupLR(optimizer, total_steps=5, warmup_steps=2)
    new_scheduler = CosineAnnealingWarmupLR(optimizer, total_steps=5, warmup_steps=2)

    # Advance the first scheduler so its state diverges from the fresh one.
    for _ in range(5):
        scheduler.step()

    # Round-trip the state dict and verify the restored scheduler matches.
    state_dict = scheduler.state_dict()
    new_scheduler.load_state_dict(state_dict)
    assert state_dict == new_scheduler.state_dict()


if __name__ == "__main__":
    test_lr_scheduler_save_load()