[checkpointio] fix gemini and hybrid parallel optim checkpoint (#5347)

* [checkpointio] fix hybrid parallel optim checkpoint

* [extension] fix cuda extension

* [checkpointio] fix gemini optimizer checkpoint

* polish code
Hongxin Liu 2024-02-01 16:13:06 +08:00 committed by GitHub
parent c5239840e6
commit ffffc32dc7
5 changed files with 35 additions and 8 deletions
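
In short: when saving an optimizer checkpoint, both HybridParallelCheckpointIO and GeminiOptimizer serialized the param groups captured when the optimizer was first wrapped (optimizer.param_info and param_groups_backup respectively), so any hyperparameter changed afterwards, such as a learning rate adjusted by a scheduler, was silently lost on reload. The diffs below rebuild each saved group from the live optimizer.param_groups, taking only the checkpoint-format "params" index lists from the captured copy, and the tests now mutate the lr after stepping and assert that it survives a save/load round trip.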


@@ -1,7 +1,7 @@
 import copy
-from functools import reduce
 import logging
 import os
+from functools import reduce
 from pathlib import Path
 from shutil import rmtree
 from typing import Dict, Iterator, Optional, OrderedDict, Tuple
@@ -445,7 +445,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         # Store param groups.
         index_file.append_meta_data("param_groups", param_group_file)
         group_file_path = os.path.join(checkpoint, param_group_file)
-        save_param_groups(optimizer.param_info, group_file_path)
+        param_groups = [
+            {**group, "params": group_info["params"]}
+            for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+        ]
+        save_param_groups({"param_groups": param_groups}, group_file_path)
         # Store index file.
         index_file.append_meta_data("total_size", total_size)
         index_file.write_index_file(save_index_file)
@@ -504,7 +508,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         # Store param groups.
         final_index_file.append_meta_data("param_groups", param_group_file)
         group_file_path = os.path.join(checkpoint, param_group_file)
-        save_param_groups(optimizer.param_info, group_file_path)
+        param_groups = [
+            {**group, "params": group_info["params"]}
+            for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+        ]
+        save_param_groups({"param_groups": param_groups}, group_file_path)
 
         final_index_file.write_index_file(final_index_file_path)
         rmtree(tmp_index_file_folder)
@@ -718,7 +726,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         if self.pp_size == 1:
             # When pipeline is not used, let master rank directly save the collected state_dict.
-            state_dict = {"param_groups": optimizer.param_info["param_groups"], "state": local_states}
+            param_groups = [
+                {**group, "params": group_info["params"]}
+                for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+            ]
+            state_dict = {"param_groups": param_groups, "state": local_states}
             if self.coordinator.is_master():
                 save_state_dict(state_dict, checkpoint, use_safetensors=False)
         else:
@@ -729,7 +741,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             # Only the master rank do the saving.
             if self.coordinator.is_master():
-                state_dict = {"param_groups": optimizer.param_info["param_groups"], "state": dict()}
+                param_groups = [
+                    {**group, "params": group_info["params"]}
+                    for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+                ]
+                state_dict = {"param_groups": param_groups, "state": dict()}
                 for _states in states_list:
                     state_dict["state"].update(_states)
                 save_state_dict(state_dict, checkpoint, use_safetensors=False)
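
The same four-line merge appears in every save path above: the live optimizer.param_groups entry supplies the current hyperparameter values, while param_info supplies the integer parameter ids that belong in a checkpoint. A minimal standalone sketch of the merge (toy data; param_info here is a hand-made stand-in for what the plugin records at boost time):

    import torch

    params = [torch.nn.Parameter(torch.zeros(2))]
    optimizer = torch.optim.Adam(params, lr=0.001)
    # Stand-in for the mapping captured when the optimizer was wrapped.
    param_info = {"param_groups": [{"params": [0], "lr": 0.001}]}

    # A scheduler (or the user) later changes the learning rate in place.
    optimizer.param_groups[0]["lr"] = 0.1

    param_groups = [
        {**group, "params": group_info["params"]}
        for group, group_info in zip(optimizer.param_groups, param_info["param_groups"])
    ]
    assert param_groups[0]["lr"] == 0.1      # live value, not the stale 0.001
    assert param_groups[0]["params"] == [0]  # checkpoint-friendly ids, not tensors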


@@ -621,7 +621,10 @@ class GeminiOptimizer(OptimizerWrapper):
         Return the param_groups in Pytorch format when saving to checkpoint.
         """
-        param_groups = copy.deepcopy(self.param_groups_backup)
+        param_groups = [
+            {**group, "params": group_info["params"]}
+            for group, group_info in zip(self.optim.param_groups, self.param_groups_backup)
+        ]
         # To be compatible with pytorch checkpointing,
         # store extra hyperparameters used by pytorch Adam optimizer.
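
GeminiOptimizer gets the same treatment: self.optim.param_groups provides the current hyperparameters and param_groups_backup (already in checkpoint format) provides the "params" lists. The old copy.deepcopy(self.param_groups_backup) froze hyperparameters at whatever values the backup recorded.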


@@ -1,7 +1,10 @@
 import os
+import time
 from abc import abstractmethod
+from pathlib import Path
 from typing import List
 
+from .base_extension import _Extension
 from .cpp_extension import _CppExtension
 from .utils import check_pytorch_version, check_system_pytorch_cuda_match, set_cuda_arch_list
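
This hunk is the "[extension] fix cuda extension" part of the commit: it only adds imports (time, Path, and _Extension) that the module apparently references further down, presumably fixing a NameError when the CUDA extension was built.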


@@ -97,7 +97,7 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
     new_model = model_fn()
     optimizer = HybridAdam(model.parameters(), lr=0.001)
     model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
-    new_optimizer = HybridAdam(new_model.parameters(), lr=0.001)
+    new_optimizer = HybridAdam(new_model.parameters(), lr=0.01)
     new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)
 
     data = data_gen_fn()
@@ -109,6 +109,8 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
     booster.backward(loss, optimizer)
     optimizer.step()
+    for group in optimizer.param_groups:
+        group["lr"] = 0.1
 
     with shared_tempdir() as tempdir:
         model_ckpt_path = f"{tempdir}/model"
@@ -127,6 +129,8 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
         check_state_dict_equal(
             optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(only_rank_0=False), False
         )
+        for group in new_optimizer.param_groups:
+            assert group["lr"] == 0.1
 
         # Check the new model/optimizer can successfully run.
         data = data_gen_fn()


@@ -83,7 +83,8 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf
     optimizer.backward(loss)
     optimizer.step()
-
+    for group in optimizer.param_groups:
+        group["lr"] = 0.1
 
     with shared_tempdir() as tempdir:
         model_ckpt_path = f"{tempdir}/model"
         optimizer_ckpt_path = f"{tempdir}/optimizer"
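
Both test files add the same regression check. A plain-PyTorch analogue of the property being asserted (stock torch.optim already preserves it; the fixed plugins now do too):

    import torch

    model = torch.nn.Linear(2, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    model(torch.randn(1, 2)).sum().backward()
    optimizer.step()

    for group in optimizer.param_groups:
        group["lr"] = 0.1                  # change a hyperparameter after stepping

    saved = optimizer.state_dict()         # stand-in for booster.save_optimizer(...)
    new_optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    new_optimizer.load_state_dict(saved)   # stand-in for booster.load_optimizer(...)

    for group in new_optimizer.param_groups:
        assert group["lr"] == 0.1          # the updated lr survives the round trip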