mirror of https://github.com/hpcaitech/ColossalAI
Browse Source
* support async optimizer save/load * fix * fix * support pin mem * Update low_level_zero_plugin.py * fix * fix * fix * fix * fix pull/6124/merge
flybird11111
4 days ago
committed by
Hongxin Liu
15 changed files with 374 additions and 46 deletions
@ -0,0 +1,127 @@
|
||||
import tempfile |
||||
from copy import deepcopy |
||||
|
||||
import torch |
||||
|
||||
from colossalai.utils.safetensors import load_flat, save_nested |
||||
|
||||
try: |
||||
from tensornvme.async_file_io import AsyncFileWriter |
||||
except ModuleNotFoundError: |
||||
raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer") |
||||
|
||||
from colossalai.testing import check_state_dict_equal |
||||
|
||||
|
||||
def _write_safetensors(path, *save_args):
    """Save a checkpoint to ``path`` through an ``AsyncFileWriter``.

    Args:
        path: Destination file path; opened in binary write mode.
        *save_args: Positional arguments forwarded unchanged to
            ``save_nested`` after the writer (the state dict, optionally
            followed by metadata).
    """
    # The context manager closes the file even when saving raises; previously
    # the handle was only closed on the success path.
    with open(path, "wb") as fp:
        f_writer = AsyncFileWriter(fp=fp, n_entries=191, backend="pthread")
        save_nested(f_writer, *save_args)
        # Same call order as before: both sync calls run before the file closes.
        f_writer.sync_before_step()
        f_writer.synchronize()


def test_save_load():
    """Round-trip an optimizer state dict and a model state dict.

    Each dict is written with ``save_nested`` via an ``AsyncFileWriter``,
    read back with ``load_flat``, and compared against the original with
    ``check_state_dict_equal``.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        # Optimizer checkpoint: three parameter states plus one param group.
        optimizer_state_dict = {
            idx: {
                "step": torch.tensor(1.0),
                "exp_avg": torch.rand((1024, 1024)),
                "exp_avg_sq": torch.rand((1024, 1024)),
            }
            for idx in range(3)
        }
        group_dict = {
            "param_groups": [
                {
                    "lr": 0.001,
                    "betas": (0.9, 0.999),
                    "eps": 1e-08,
                    "weight_decay": 0,
                    "bias_correction": True,
                    "params": list(range(62)),
                }
            ]
        }
        # Pass a copy as metadata so group_dict stays untouched for the
        # comparison below.
        metadata = deepcopy(group_dict)
        optimizer_saved_path = f"{tempdir}/save_optimizer.safetensors"
        _write_safetensors(optimizer_saved_path, optimizer_state_dict, metadata)

        loaded_optimizer = load_flat(optimizer_saved_path)
        state_dict = loaded_optimizer["state"]
        group = {"param_groups": loaded_optimizer["param_groups"]}
        check_state_dict_equal(optimizer_state_dict, state_dict)
        check_state_dict_equal(group_dict, group)

        # Model checkpoint: flat mapping of names to tensors, no metadata.
        model_state_dict = {
            "module.weight0": torch.rand((1024, 1024)),
            "module.weight1": torch.rand((1024, 1024)),
            "module.weight2": torch.rand((1024, 1024)),
        }
        model_saved_path = f"{tempdir}/save_model.safetensors"
        _write_safetensors(model_saved_path, model_state_dict)

        loaded_model = load_flat(model_saved_path)
        check_state_dict_equal(model_state_dict, loaded_model)
Loading…
Reference in new issue