mirror of https://github.com/hpcaitech/ColossalAI
import torch.nn as nn
from torch.optim import Adam

from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR


def test_lr_scheduler_save_load():
    # Build a toy model/optimizer pair and two identically configured schedulers.
    model = nn.Linear(10, 10)
    optimizer = Adam(model.parameters(), lr=1e-3)
    scheduler = CosineAnnealingWarmupLR(optimizer, total_steps=5, warmup_steps=2)
    new_scheduler = CosineAnnealingWarmupLR(optimizer, total_steps=5, warmup_steps=2)

    # Advance the first scheduler, then transfer its state to the fresh one.
    for _ in range(5):
        scheduler.step()
    state_dict = scheduler.state_dict()
    new_scheduler.load_state_dict(state_dict)

    # The restored scheduler must report exactly the same state.
    assert state_dict == new_scheduler.state_dict()


if __name__ == "__main__":
    test_lr_scheduler_save_load()
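For reference, a minimal sketch (not part of the test file above) of how this state_dict round-trip is typically used when checkpointing and resuming a training run; the checkpoint path, step counts, and surrounding save/load flow are illustrative assumptions rather than ColossalAI-prescribed usage.

import torch
import torch.nn as nn
from torch.optim import Adam

from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR

model = nn.Linear(10, 10)
optimizer = Adam(model.parameters(), lr=1e-3)
scheduler = CosineAnnealingWarmupLR(optimizer, total_steps=100, warmup_steps=10)

# ... run some training steps, each followed by scheduler.step() ...

# Persist model, optimizer, and scheduler state together (path is hypothetical).
torch.save(
    {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
    },
    "checkpoint.pt",
)

# On resume, rebuild the same objects and restore their states.
ckpt = torch.load("checkpoint.pt")
model.load_state_dict(ckpt["model"])
optimizer.load_state_dict(ckpt["optimizer"])
scheduler.load_state_dict(ckpt["scheduler"])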