# Mirror of https://github.com/hpcaitech/ColossalAI (Python, 21 lines, 661 B).

import torch.nn as nn
from torch.optim import Adam
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR


def test_lr_scheduler_save_load():
    """Check that a stepped ``CosineAnnealingWarmupLR`` survives a state_dict round trip.

    One scheduler is advanced 5 steps (configured with ``total_steps=5``,
    ``warmup_steps=2``). Its ``state_dict()`` is then loaded into a second
    scheduler built with the same arguments. The reloaded scheduler must
    report exactly the same ``state_dict()``.
    """
    model = nn.Linear(10, 10)
    optimizer = Adam(model.parameters(), lr=1e-3)
    scheduler = CosineAnnealingWarmupLR(optimizer, total_steps=5, warmup_steps=2)
    # NOTE(review): both schedulers are attached to the same optimizer; only
    # the schedulers' own state is compared below, not the optimizer's lr.
    new_scheduler = CosineAnnealingWarmupLR(optimizer, total_steps=5, warmup_steps=2)

    for _ in range(5):
        # PyTorch expects optimizer.step() before scheduler.step(); stepping the
        # scheduler first triggers a UserWarning and skips the first schedule
        # value. No backward pass has run, so no parameter has a gradient and
        # this call does not change the model weights.
        optimizer.step()
        scheduler.step()

    state_dict = scheduler.state_dict()
    new_scheduler.load_state_dict(state_dict)
    assert state_dict == new_scheduler.state_dict()


if "__main__" == __name__:
    # Running this file directly (outside pytest): execute the save/load check.
    test_lr_scheduler_save_load()