import tempfile

import torch
from safetensors.torch import load_file

from colossalai.utils.safetensors import load_flat, move_and_save, save, save_nested

try:
    from tensornvme.async_file_io import AsyncFileWriter
except ModuleNotFoundError:
    raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")

from colossalai.testing import check_state_dict_equal
from colossalai.utils import get_current_device


def test_save_load():
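    """Round-trip save/load checks for colossalai.utils.safetensors async I/O helpers."""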
    with tempfile.TemporaryDirectory() as tempdir:
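        # A mock Adam-style optimizer state dict: three per-parameter "state"
        # entries plus one param group covering parameter ids 0-61.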
        optimizer_state_dict = {
            "state": {
                0: {
                    "step": torch.tensor(1.0),
                    "exp_avg": torch.rand((1024, 1024)),
                    "exp_avg_sq": torch.rand((1024, 1024)),
                },
                1: {
                    "step": torch.tensor(1.0),
                    "exp_avg": torch.rand((1024, 1024)),
                    "exp_avg_sq": torch.rand((1024, 1024)),
                },
                2: {
                    "step": torch.tensor(1.0),
                    "exp_avg": torch.rand((1024, 1024)),
                    "exp_avg_sq": torch.rand((1024, 1024)),
                },
            },
            "param_groups": [
                {
                    "lr": 0.001,
                    "betas": (0.9, 0.999),
                    "eps": 1e-08,
                    "weight_decay": 0,
                    "bias_correction": True,
                    # Equivalent to the explicit list [0, 1, ..., 61].
                    "params": list(range(62)),
                }
            ],
        }
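
        # Write the nested optimizer dict asynchronously, then confirm that a
        # flat load reproduces it exactly.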
        optimizer_saved_path = f"{tempdir}/save_optimizer.safetensors"
        f_writer = AsyncFileWriter(fp=open(optimizer_saved_path, "wb"), n_entries=191, backend="pthread")
        save_nested(f_writer, optimizer_state_dict)
        f_writer.sync_before_step()
        f_writer.synchronize()
        f_writer.fp.close()
        load_state_dict = load_flat(optimizer_saved_path)
        check_state_dict_equal(load_state_dict, optimizer_state_dict)
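
        # Repeat the round trip for a shard holding only the "state" sub-dict.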
        optimizer_shard_saved_path = f"{tempdir}/save_optimizer_shard.safetensors"
        f_writer = AsyncFileWriter(fp=open(optimizer_shard_saved_path, "wb"), n_entries=191, backend="pthread")
        save_nested(f_writer, optimizer_state_dict["state"])
        f_writer.sync_before_step()
        f_writer.synchronize()
        f_writer.fp.close()
        load_state_dict_shard = load_flat(optimizer_shard_saved_path)
        check_state_dict_equal(load_state_dict_shard, optimizer_state_dict["state"])
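
        # A flat model state dict goes through save(); the resulting file is
        # readable by the stock safetensors load_file().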
        model_state_dict = {
            "module.weight0": torch.rand((1024, 1024)),
            "module.weight1": torch.rand((1024, 1024)),
            "module.weight2": torch.rand((1024, 1024)),
        }
        model_saved_path = f"{tempdir}/save_model.safetensors"
        f_writer = AsyncFileWriter(fp=open(model_saved_path, "wb"), n_entries=191, backend="pthread")
        save(f_writer, model_state_dict)
        f_writer.sync_before_step()
        f_writer.synchronize()
        f_writer.fp.close()
        load_state_dict = load_file(model_saved_path)
        check_state_dict_equal(model_state_dict, load_state_dict)
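
        # move_and_save writes device tensors via the pre-allocated pinned
        # host copies; the saved file should match the original CPU tensors.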
        model_state_dict_cuda = {k: v.to(get_current_device()) for k, v in model_state_dict.items()}
        model_state_pinned = {k: v.pin_memory() for k, v in model_state_dict.items()}
        model_saved_path = f"{tempdir}/save_model_cuda.safetensors"
        f_writer = AsyncFileWriter(fp=open(model_saved_path, "wb"), n_entries=191, backend="pthread")
        move_and_save(f_writer, model_state_dict_cuda, model_state_pinned)
        f_writer.sync_before_step()
        f_writer.synchronize()
        f_writer.fp.close()
        load_state_dict = load_file(model_saved_path)
        check_state_dict_equal(model_state_dict, load_state_dict)
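
# Allow running this file directly, without pytest.
if __name__ == "__main__":
    test_save_load()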