ColossalAI/tests/test_checkpoint_io/test_safetensors_async_io.py

import tempfile

import pytest
import torch
from safetensors.torch import load_file

from colossalai.checkpoint_io.utils import create_pinned_state_dict
from colossalai.testing import check_state_dict_equal, clear_cache_before_run
from colossalai.utils import get_current_device
from colossalai.utils.safetensors import load_flat, move_and_save, save, save_nested
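
# These tests exercise ColossalAI's async safetensors checkpoint I/O:
# building pinned host copies of (possibly nested) state dicts with
# create_pinned_state_dict, and round-tripping model / optimizer state
# through save, save_nested and move_and_save followed by load_file /
# load_flat.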


def gen_optim_state_dict():
    # A mock Adam-style optimizer state dict: per-parameter state (step
    # counter plus exp_avg / exp_avg_sq moment tensors) and one param group
    # with torch.optim.Adam-like hyperparameters.
    return {
        "state": {
            0: {
                "step": torch.tensor(1.0),
                "exp_avg": torch.rand((1024, 1024)),
                "exp_avg_sq": torch.rand((1024, 1024)),
            },
            1: {
                "step": torch.tensor(1.0),
                "exp_avg": torch.rand((1024, 1024)),
                "exp_avg_sq": torch.rand((1024, 1024)),
            },
            2: {
                "step": torch.tensor(1.0),
                "exp_avg": torch.rand((1024, 1024)),
                "exp_avg_sq": torch.rand((1024, 1024)),
            },
        },
        "param_groups": [
            {
                "lr": 0.001,
                "betas": (0.9, 0.999),
                "eps": 1e-08,
                "weight_decay": 0,
                "bias_correction": True,
                # Same contents as the original literal list of param ids 0..61.
                "params": list(range(62)),
            }
        ],
    }


def gen_model_state_dict():
    # Three random 1024 x 1024 weight tensors standing in for a model
    # state dict.
    return {
        "module.weight0": torch.rand((1024, 1024)),
        "module.weight1": torch.rand((1024, 1024)),
        "module.weight2": torch.rand((1024, 1024)),
    }
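

# Pinned (page-locked) host buffers are what make the async writers below
# non-blocking: copies between CUDA device memory and the host can overlap
# with compute only when the host side is pinned.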
@pytest.mark.parametrize("empty", [True, False])
@pytest.mark.parametrize("num_threads", [1, 4])
def test_create_pin(empty: bool, num_threads: int):
model_state_dict = gen_model_state_dict()
model_state_dict_pinned = create_pinned_state_dict(model_state_dict, empty=empty, num_threads=num_threads)
for k in model_state_dict.keys():
assert model_state_dict_pinned[k].is_pinned()
if not empty:
assert torch.equal(model_state_dict_pinned[k], model_state_dict[k])
optim_state_dict = gen_optim_state_dict()
optim_state_dict_pinned = create_pinned_state_dict(optim_state_dict, empty=empty, num_threads=num_threads)
for k in optim_state_dict.keys():
if k == "state":
for idx in optim_state_dict[k].keys():
for kk in optim_state_dict[k][idx].keys():
assert optim_state_dict_pinned[k][idx][kk].is_pinned()
if not empty:
assert torch.equal(optim_state_dict_pinned[k][idx][kk], optim_state_dict[k][idx][kk])
else:
assert optim_state_dict[k] == optim_state_dict_pinned[k]
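

# save / save_nested / move_and_save return an async file writer; the test
# calls sync_before_step() and synchronize() so every queued write lands on
# disk before the file is read back for comparison.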
@clear_cache_before_run()
def test_save_load():
    with tempfile.TemporaryDirectory() as tempdir:
        # Full optimizer state dict: save_nested flattens the nested keys,
        # load_flat rebuilds the original structure.
        optimizer_state_dict = gen_optim_state_dict()
        optimizer_saved_path = f"{tempdir}/save_optimizer.safetensors"
        f_writer = save_nested(optimizer_saved_path, optimizer_state_dict)
        f_writer.sync_before_step()
        f_writer.synchronize()
        del f_writer
        load_state_dict = load_flat(optimizer_saved_path)
        check_state_dict_equal(load_state_dict, optimizer_state_dict)

        # A shard holding only the "state" sub-dict round-trips the same way.
        optimizer_shard_saved_path = f"{tempdir}/save_optimizer_shard.safetensors"
        f_writer = save_nested(optimizer_shard_saved_path, optimizer_state_dict["state"])
        f_writer.sync_before_step()
        f_writer.synchronize()
        del f_writer
        load_state_dict_shard = load_flat(optimizer_shard_saved_path)
        check_state_dict_equal(load_state_dict_shard, optimizer_state_dict["state"])

        # Flat model state dict written asynchronously and read back with
        # the stock safetensors load_file.
        model_state_dict = gen_model_state_dict()
        model_saved_path = f"{tempdir}/save_model.safetensors"
        f_writer = save(model_saved_path, model_state_dict)
        f_writer.sync_before_step()
        f_writer.synchronize()
        del f_writer
        load_state_dict = load_file(model_saved_path)
        check_state_dict_equal(model_state_dict, load_state_dict)

        # CUDA tensors are staged through the supplied pinned host buffers
        # by move_and_save before being written to disk.
        model_state_dict_cuda = {k: v.to(get_current_device()) for k, v in model_state_dict.items()}
        model_state_pinned = {k: v.pin_memory() for k, v in model_state_dict.items()}
        model_saved_path = f"{tempdir}/save_model_cuda.safetensors"
        f_writer = move_and_save(model_saved_path, model_state_dict_cuda, model_state_pinned)
        f_writer.sync_before_step()
        f_writer.synchronize()
        del f_writer
        load_state_dict = load_file(model_saved_path)
        check_state_dict_equal(model_state_dict, load_state_dict)
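

# To run these tests directly (invocation is illustrative, not part of the
# original file; assumes pytest and a CUDA-capable colossalai install):
#   pytest tests/test_checkpoint_io/test_safetensors_async_io.py -v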