mirror of https://github.com/hpcaitech/ColossalAI
[checkpointio] fix gemini and hybrid parallel optim checkpoint (#5347)
* [checkpointio] fix hybrid parallel optim checkpoint
* [extension] fix cuda extension
* [checkpointio] fix gemini optimizer checkpoint
* polish code
parent c5239840e6
commit ffffc32dc7

@@ -1,7 +1,7 @@
 import copy
-from functools import reduce
 import logging
 import os
+from functools import reduce
 from pathlib import Path
 from shutil import rmtree
 from typing import Dict, Iterator, Optional, OrderedDict, Tuple

@@ -445,7 +445,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         # Store param groups.
         index_file.append_meta_data("param_groups", param_group_file)
         group_file_path = os.path.join(checkpoint, param_group_file)
-        save_param_groups(optimizer.param_info, group_file_path)
+        param_groups = [
+            {**group, "params": group_info["params"]}
+            for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+        ]
+        save_param_groups({"param_groups": param_groups}, group_file_path)
         # Store index file.
         index_file.append_meta_data("total_size", total_size)
         index_file.write_index_file(save_index_file)

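The hunk above (and the ones that follow) carries the core of the checkpoint fix: instead of writing optimizer.param_info["param_groups"] verbatim, which only reflects the hyperparameters recorded when the optimizer was wrapped, the saved groups now take their hyperparameters from the live optimizer.param_groups and keep only the "params" index lists from param_info. A minimal, self-contained sketch of that merge using plain PyTorch, with a hypothetical param_info stand-in (this is not ColossalAI code):

    import torch

    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Hypothetical stand-in for optimizer.param_info["param_groups"]: hyperparameters
    # captured at wrap time plus the parameter ids the checkpoint should store.
    param_info = {"param_groups": [{"lr": 1e-3, "params": [0, 1]}]}

    optimizer.param_groups[0]["lr"] = 0.1  # hyperparameter changed later, e.g. by a scheduler

    param_groups = [
        {**group, "params": group_info["params"]}
        for group, group_info in zip(optimizer.param_groups, param_info["param_groups"])
    ]

    assert param_groups[0]["lr"] == 0.1         # current value is kept
    assert param_groups[0]["params"] == [0, 1]  # recorded ids replace the live tensors

The dict unpacking puts the live group's entries first, so every hyperparameter (lr, betas, weight_decay, ...) comes from the running optimizer, while the explicit "params" key overrides the tensor list with the saved indices.
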
@@ -504,7 +508,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         # Store param groups.
         final_index_file.append_meta_data("param_groups", param_group_file)
         group_file_path = os.path.join(checkpoint, param_group_file)
-        save_param_groups(optimizer.param_info, group_file_path)
+        param_groups = [
+            {**group, "params": group_info["params"]}
+            for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+        ]
+        save_param_groups({"param_groups": param_groups}, group_file_path)

         final_index_file.write_index_file(final_index_file_path)
         rmtree(tmp_index_file_folder)

@@ -718,7 +726,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):

         if self.pp_size == 1:
             # When pipeline is not used, let master rank directly save the collected state_dict.
-            state_dict = {"param_groups": optimizer.param_info["param_groups"], "state": local_states}
+            param_groups = [
+                {**group, "params": group_info["params"]}
+                for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+            ]
+            state_dict = {"param_groups": param_groups, "state": local_states}
             if self.coordinator.is_master():
                 save_state_dict(state_dict, checkpoint, use_safetensors=False)
         else:

@@ -729,7 +741,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):

             # Only the master rank do the saving.
             if self.coordinator.is_master():
-                state_dict = {"param_groups": optimizer.param_info["param_groups"], "state": dict()}
+                param_groups = [
+                    {**group, "params": group_info["params"]}
+                    for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+                ]
+                state_dict = {"param_groups": param_groups, "state": dict()}
                 for _states in states_list:
                     state_dict["state"].update(_states)
                 save_state_dict(state_dict, checkpoint, use_safetensors=False)

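In the pipeline-parallel branch above, each stage contributes only the optimizer states of its own parameters, keyed by global parameter index, so the master rank can assemble the full checkpoint with plain dict updates. A tiny sketch with made-up data (not the ColossalAI gather logic):

    # Two gathered per-stage state dicts, keyed by global parameter index.
    states_list = [
        {0: {"step": 10}, 1: {"step": 10}},  # from pipeline stage 0
        {2: {"step": 10}, 3: {"step": 10}},  # from pipeline stage 1
    ]

    state_dict = {"param_groups": [], "state": dict()}
    for _states in states_list:
        state_dict["state"].update(_states)

    assert sorted(state_dict["state"]) == [0, 1, 2, 3]  # all stages merged into one mapping
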
@@ -621,7 +621,10 @@ class GeminiOptimizer(OptimizerWrapper):
         Return the param_groups in Pytorch format when saving to checkpoint.
         """

-        param_groups = copy.deepcopy(self.param_groups_backup)
+        param_groups = [
+            {**group, "params": group_info["params"]}
+            for group, group_info in zip(self.optim.param_groups, self.param_groups_backup)
+        ]

         # To be compatible with pytorch checkpointing,
         # store extra hyperparameters used by pytorch Adam optimizer.

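The docstring above promises param_groups "in Pytorch format": the optimizer's current hyperparameter values, with parameters listed as integer indices rather than tensors. Deep-copying self.param_groups_backup reproduced the values captured at wrap time; merging against self.optim.param_groups reproduces the current ones. A short plain-torch illustration of that format (not GeminiOptimizer code):

    import torch

    opt = torch.optim.Adam(torch.nn.Linear(2, 2).parameters(), lr=1e-3)
    opt.param_groups[0]["lr"] = 0.1  # later hyperparameter change

    groups = opt.state_dict()["param_groups"]
    assert groups[0]["lr"] == 0.1         # the current value is saved, not the constructor's
    assert groups[0]["params"] == [0, 1]  # indices (weight, bias), not tensors
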
@@ -1,7 +1,10 @@
 import os
+import time
 from abc import abstractmethod
+from pathlib import Path
 from typing import List

+from .base_extension import _Extension
 from .cpp_extension import _CppExtension
 from .utils import check_pytorch_version, check_system_pytorch_cuda_match, set_cuda_arch_list

@@ -97,7 +97,7 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
     new_model = model_fn()
     optimizer = HybridAdam(model.parameters(), lr=0.001)
     model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
-    new_optimizer = HybridAdam(new_model.parameters(), lr=0.001)
+    new_optimizer = HybridAdam(new_model.parameters(), lr=0.01)
     new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)

     data = data_gen_fn()

@@ -109,6 +109,8 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha

     booster.backward(loss, optimizer)
     optimizer.step()
+    for group in optimizer.param_groups:
+        group["lr"] = 0.1

     with shared_tempdir() as tempdir:
         model_ckpt_path = f"{tempdir}/model"

|
@ -127,6 +129,8 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
|
||||||
check_state_dict_equal(
|
check_state_dict_equal(
|
||||||
optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(only_rank_0=False), False
|
optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(only_rank_0=False), False
|
||||||
)
|
)
|
||||||
|
for group in new_optimizer.param_groups:
|
||||||
|
assert group["lr"] == 0.1
|
||||||
|
|
||||||
# Check the new model/optimizer can successfully run.
|
# Check the new model/optimizer can successfully run.
|
||||||
data = data_gen_fn()
|
data = data_gen_fn()
|
||||||
|
|
|
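The test additions above and below follow one pattern: bump the learning rate after optimizer.step(), create the second optimizer with a different lr, and assert that loading the checkpoint restores lr == 0.1 rather than the constructor value. A standalone version of that check using plain torch and torch.save/torch.load instead of the booster (names and flow are illustrative, not the actual test):

    import os
    import tempfile

    import torch

    def check_lr_round_trip():
        model = torch.nn.Linear(4, 1)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        model(torch.randn(8, 4)).sum().backward()
        optimizer.step()
        for group in optimizer.param_groups:
            group["lr"] = 0.1  # simulate a post-step hyperparameter update

        with tempfile.TemporaryDirectory() as tempdir:
            ckpt = os.path.join(tempdir, "optimizer.pt")
            torch.save(optimizer.state_dict(), ckpt)

            # Fresh optimizer built with a *different* lr, as in the updated tests.
            new_optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
            new_optimizer.load_state_dict(torch.load(ckpt))
            for group in new_optimizer.param_groups:
                assert group["lr"] == 0.1  # the saved lr wins over the constructor default

    check_lr_round_trip()
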
@@ -83,7 +83,8 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf
     optimizer.backward(loss)

     optimizer.step()
+    for group in optimizer.param_groups:
+        group["lr"] = 0.1
     with shared_tempdir() as tempdir:
         model_ckpt_path = f"{tempdir}/model"
         optimizer_ckpt_path = f"{tempdir}/optimizer"