mirror of https://github.com/hpcaitech/ColossalAI
[checkpointio] fix gemini and hybrid parallel optim checkpoint (#5347)
* [checkpointio] fix hybrid parallel optim checkpoint
* [extension] fix cuda extension
* [checkpointio] fix gemini optimizer checkpoint
* polish code

pull/5355/head
parent c5239840e6
commit ffffc32dc7
@@ -1,7 +1,7 @@
 import copy
-from functools import reduce
 import logging
 import os
+from functools import reduce
 from pathlib import Path
 from shutil import rmtree
 from typing import Dict, Iterator, Optional, OrderedDict, Tuple

@@ -445,7 +445,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         # Store param groups.
         index_file.append_meta_data("param_groups", param_group_file)
         group_file_path = os.path.join(checkpoint, param_group_file)
-        save_param_groups(optimizer.param_info, group_file_path)
+        param_groups = [
+            {**group, "params": group_info["params"]}
+            for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+        ]
+        save_param_groups({"param_groups": param_groups}, group_file_path)
         # Store index file.
         index_file.append_meta_data("total_size", total_size)
         index_file.write_index_file(save_index_file)

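The pattern introduced in this hunk recurs throughout the commit, so here is a minimal sketch of it. This is not the library code: plain torch.optim.Adam stands in for the wrapped optimizer, and the hand-written param_info dict stands in for optimizer.param_info, which is assumed to record each group's hyperparameters and parameter indices at wrap time. The point is that the live optimizer.param_groups supply the current hyperparameters, while param_info supplies only the "params" index lists, so values changed after wrapping (for example the learning rate) are what end up in the checkpoint.

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Stand-in for optimizer.param_info: hyperparameters captured at wrap time,
# with "params" stored as integer indices rather than tensors.
param_info = {"param_groups": [{"lr": 0.001, "params": [0, 1]}]}

# The learning rate is changed after param_info was recorded.
for group in optimizer.param_groups:
    group["lr"] = 0.1

# Keep the live hyperparameters, but only the recorded parameter indices.
param_groups = [
    {**group, "params": group_info["params"]}
    for group, group_info in zip(optimizer.param_groups, param_info["param_groups"])
]
assert param_groups[0]["lr"] == 0.1 and param_groups[0]["params"] == [0, 1]
# Saving param_info["param_groups"] directly (the old code path) would have kept lr=0.001.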
@@ -504,7 +508,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         # Store param groups.
         final_index_file.append_meta_data("param_groups", param_group_file)
         group_file_path = os.path.join(checkpoint, param_group_file)
-        save_param_groups(optimizer.param_info, group_file_path)
+        param_groups = [
+            {**group, "params": group_info["params"]}
+            for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+        ]
+        save_param_groups({"param_groups": param_groups}, group_file_path)
 
         final_index_file.write_index_file(final_index_file_path)
         rmtree(tmp_index_file_folder)

@@ -718,7 +726,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
 
         if self.pp_size == 1:
             # When pipeline is not used, let master rank directly save the collected state_dict.
-            state_dict = {"param_groups": optimizer.param_info["param_groups"], "state": local_states}
+            param_groups = [
+                {**group, "params": group_info["params"]}
+                for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+            ]
+            state_dict = {"param_groups": param_groups, "state": local_states}
             if self.coordinator.is_master():
                 save_state_dict(state_dict, checkpoint, use_safetensors=False)
         else:

@@ -729,7 +741,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
 
             # Only the master rank do the saving.
             if self.coordinator.is_master():
-                state_dict = {"param_groups": optimizer.param_info["param_groups"], "state": dict()}
+                param_groups = [
+                    {**group, "params": group_info["params"]}
+                    for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+                ]
+                state_dict = {"param_groups": param_groups, "state": dict()}
                 for _states in states_list:
                     state_dict["state"].update(_states)
                 save_state_dict(state_dict, checkpoint, use_safetensors=False)

@@ -621,7 +621,10 @@ class GeminiOptimizer(OptimizerWrapper):
         Return the param_groups in Pytorch format when saving to checkpoint.
         """
 
-        param_groups = copy.deepcopy(self.param_groups_backup)
+        param_groups = [
+            {**group, "params": group_info["params"]}
+            for group, group_info in zip(self.optim.param_groups, self.param_groups_backup)
+        ]
 
         # To be compatible with pytorch checkpointing,
         # store extra hyperparameters used by pytorch Adam optimizer.

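For context on the "Pytorch format" mentioned in the docstring above: a torch.optim state dict keeps hyperparameters in "param_groups" with "params" as integer ids that index into "state". The merged groups built here reproduce that layout from the live self.optim.param_groups and the backed-up index lists, instead of deep-copying param_groups_backup, which still holds the values captured at initialization. A small, hedged illustration with plain torch (not GeminiOptimizer itself):

import torch

model = torch.nn.Linear(4, 2)
optim = torch.optim.Adam(model.parameters(), lr=0.001)
model(torch.randn(1, 4)).sum().backward()
optim.step()

sd = optim.state_dict()
# "param_groups" hold hyperparameters plus integer parameter ids;
# "state" maps those ids to per-parameter buffers (step, exp_avg, exp_avg_sq).
print(sd["param_groups"])          # e.g. [{'lr': 0.001, ..., 'params': [0, 1]}]
print(sorted(sd["state"].keys()))  # e.g. [0, 1]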
@@ -1,7 +1,10 @@
+import os
+import time
 from abc import abstractmethod
+from pathlib import Path
 from typing import List
 
 from .base_extension import _Extension
 from .cpp_extension import _CppExtension
 from .utils import check_pytorch_version, check_system_pytorch_cuda_match, set_cuda_arch_list
 

@@ -97,7 +97,7 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
     new_model = model_fn()
     optimizer = HybridAdam(model.parameters(), lr=0.001)
     model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
-    new_optimizer = HybridAdam(new_model.parameters(), lr=0.001)
+    new_optimizer = HybridAdam(new_model.parameters(), lr=0.01)
    new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)
 
     data = data_gen_fn()

@@ -109,6 +109,8 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
 
     booster.backward(loss, optimizer)
     optimizer.step()
+    for group in optimizer.param_groups:
+        group["lr"] = 0.1
 
     with shared_tempdir() as tempdir:
         model_ckpt_path = f"{tempdir}/model"

@@ -127,6 +129,8 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
         check_state_dict_equal(
             optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(only_rank_0=False), False
         )
+        for group in new_optimizer.param_groups:
+            assert group["lr"] == 0.1
 
         # Check the new model/optimizer can successfully run.
         data = data_gen_fn()

@@ -83,7 +83,8 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf
     optimizer.backward(loss)
 
     optimizer.step()
-
+    for group in optimizer.param_groups:
+        group["lr"] = 0.1
     with shared_tempdir() as tempdir:
         model_ckpt_path = f"{tempdir}/model"
         optimizer_ckpt_path = f"{tempdir}/optimizer"
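The test additions above all follow the same recipe: run a step, bump the learning rate to 0.1, save the optimizer (presumably through the booster checkpoint APIs in the elided lines), load into a freshly built optimizer with a different lr, and check that 0.1 comes back. Below is a framework-free analogue of that round trip; plain torch checkpointing is used only as a stand-in for the ColossalAI save/load path.

import os
import tempfile

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
model(torch.randn(1, 4)).sum().backward()
optimizer.step()
for group in optimizer.param_groups:
    group["lr"] = 0.1  # changed after training started, must survive the checkpoint

with tempfile.TemporaryDirectory() as tempdir:
    optimizer_ckpt_path = os.path.join(tempdir, "optimizer.pt")
    torch.save(optimizer.state_dict(), optimizer_ckpt_path)

    new_model = torch.nn.Linear(4, 2)
    new_optimizer = torch.optim.Adam(new_model.parameters(), lr=0.01)
    new_optimizer.load_state_dict(torch.load(optimizer_ckpt_path))
    for group in new_optimizer.param_groups:
        assert group["lr"] == 0.1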