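"""Tests for ColossalAI's async safetensors utilities.

Round-trips an optimizer state dict through save_nested/load_flat, a model
state dict through save/load_file, and a CUDA state dict through
move_and_save with pinned host buffers, all via tensornvme's AsyncFileWriter.
"""
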
import tempfile

import torch
from safetensors.torch import load_file

from colossalai.utils.safetensors import load_flat, move_and_save, save, save_nested

try:
    from tensornvme.async_file_io import AsyncFileWriter
except ModuleNotFoundError:
    raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")

from colossalai.testing import check_state_dict_equal
from colossalai.utils import get_current_device

|
def test_save_load():
    with tempfile.TemporaryDirectory() as tempdir:
        # Adam-style optimizer state dict with nested per-parameter tensor states.
        optimizer_state_dict = {
            "state": {
                0: {
                    "step": torch.tensor(1.0),
                    "exp_avg": torch.rand((1024, 1024)),
                    "exp_avg_sq": torch.rand((1024, 1024)),
                },
                1: {
                    "step": torch.tensor(1.0),
                    "exp_avg": torch.rand((1024, 1024)),
                    "exp_avg_sq": torch.rand((1024, 1024)),
                },
                2: {
                    "step": torch.tensor(1.0),
                    "exp_avg": torch.rand((1024, 1024)),
                    "exp_avg_sq": torch.rand((1024, 1024)),
                },
            },
            "param_groups": [
                {
                    "lr": 0.001,
                    "betas": (0.9, 0.999),
                    "eps": 1e-08,
                    "weight_decay": 0,
                    "bias_correction": True,
                    "params": list(range(62)),  # parameter ids 0..61
                }
            ],
        }

optimizer_saved_path = f"{tempdir}/save_optimizer.safetensors" |
|
f_writer = AsyncFileWriter(fp=open(optimizer_saved_path, "wb"), n_entries=191, backend="pthread") |
|
save_nested(f_writer, optimizer_state_dict) |
|
f_writer.sync_before_step() |
|
f_writer.synchronize() |
|
f_writer.fp.close() |
|
load_state_dict = load_flat(optimizer_saved_path) |
|
check_state_dict_equal(load_state_dict, optimizer_state_dict) |
|
|
|
optimizer_shard_saved_path = f"{tempdir}/save_optimizer_shard.safetensors" |
|
f_writer = AsyncFileWriter(fp=open(optimizer_shard_saved_path, "wb"), n_entries=191, backend="pthread") |
|
save_nested(f_writer, optimizer_state_dict["state"]) |
|
f_writer.sync_before_step() |
|
f_writer.synchronize() |
|
f_writer.fp.close() |
|
load_state_dict_shard = load_flat(optimizer_shard_saved_path) |
|
check_state_dict_equal(load_state_dict_shard, optimizer_state_dict["state"]) |
|
|
|
        # A flat model state dict goes through save() and must be readable by
        # the standard safetensors load_file().
        model_state_dict = {
            "module.weight0": torch.rand((1024, 1024)),
            "module.weight1": torch.rand((1024, 1024)),
            "module.weight2": torch.rand((1024, 1024)),
        }
        model_saved_path = f"{tempdir}/save_model.safetensors"
        f_writer = AsyncFileWriter(fp=open(model_saved_path, "wb"), n_entries=191, backend="pthread")
        save(f_writer, model_state_dict)
        f_writer.sync_before_step()
        f_writer.synchronize()
        f_writer.fp.close()
        load_state_dict = load_file(model_saved_path)
        check_state_dict_equal(model_state_dict, load_state_dict)

        # move_and_save stages device tensors into pinned host buffers before
        # the async write, so the loaded file must match the CPU originals.
        model_state_dict_cuda = {k: v.to(get_current_device()) for k, v in model_state_dict.items()}
        model_state_pinned = {k: v.pin_memory() for k, v in model_state_dict.items()}
        model_saved_path = f"{tempdir}/save_model_cuda.safetensors"
        f_writer = AsyncFileWriter(fp=open(model_saved_path, "wb"), n_entries=191, backend="pthread")
        move_and_save(f_writer, model_state_dict_cuda, model_state_pinned)
        f_writer.sync_before_step()
        f_writer.synchronize()
        f_writer.fp.close()
        load_state_dict = load_file(model_saved_path)
        check_state_dict_equal(model_state_dict, load_state_dict)
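

# Not in the original module; a minimal entry point so the file can be run
# directly (the move_and_save case assumes an accelerator device is available).
if __name__ == "__main__":
    test_save_load()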