
[checkpointio] fix gemini and hybrid parallel optim checkpoint (#5347)

* [checkpointio] fix hybrid parallel optim checkpoint

* [extension] fix cuda extension

* [checkpointio] fix gemini optimizer checkpoint

* polish code
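
What this fixes: optimizer.param_info["param_groups"] is a snapshot taken when the optimizer is wrapped at boost time, so it still holds the hyperparameters (e.g. lr) from that moment. Saving it verbatim silently dropped any hyperparameter changed later in training, for example by a scheduler. The patch instead rebuilds each saved group from the live optimizer.param_groups, taking only the checkpoint-friendly "params" indices from the snapshot; the new tests mutate lr after a step and assert the value survives a save/load round trip.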
Authored by Hongxin Liu, committed via GitHub, 10 months ago
commit ffffc32dc7
Changed files:
  26  colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
   5  colossalai/zero/gemini/gemini_optimizer.py
   3  extensions/cuda_extension.py
   6  tests/test_checkpoint_io/test_gemini_checkpoint_io.py
   3  tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py

colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py (26 lines changed)

@@ -1,7 +1,7 @@
 import copy
-from functools import reduce
 import logging
 import os
+from functools import reduce
 from pathlib import Path
 from shutil import rmtree
 from typing import Dict, Iterator, Optional, OrderedDict, Tuple
@@ -445,7 +445,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         # Store param groups.
         index_file.append_meta_data("param_groups", param_group_file)
         group_file_path = os.path.join(checkpoint, param_group_file)
-        save_param_groups(optimizer.param_info, group_file_path)
+        param_groups = [
+            {**group, "params": group_info["params"]}
+            for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+        ]
+        save_param_groups({"param_groups": param_groups}, group_file_path)
         # Store index file.
         index_file.append_meta_data("total_size", total_size)
         index_file.write_index_file(save_index_file)
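
For intuition, here is a minimal standalone sketch of the merge pattern used throughout this patch (toy dicts only; nothing below is ColossalAI API, and the hyperparameter values are made up):

    # Live groups: up-to-date hyperparameters, but "params" holds tensor objects.
    live_groups = [{"lr": 0.1, "betas": (0.9, 0.999), "params": ["<tensor0>", "<tensor1>"]}]
    # Snapshot taken at boost time: stale lr, but "params" holds stable integer ids.
    saved_info = {"param_groups": [{"lr": 0.001, "params": [0, 1]}]}

    param_groups = [
        {**group, "params": group_info["params"]}
        for group, group_info in zip(live_groups, saved_info["param_groups"])
    ]
    assert param_groups[0]["lr"] == 0.1         # current hyperparameters win
    assert param_groups[0]["params"] == [0, 1]  # serializable ids are kept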
@@ -504,7 +508,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         # Store param groups.
         final_index_file.append_meta_data("param_groups", param_group_file)
         group_file_path = os.path.join(checkpoint, param_group_file)
-        save_param_groups(optimizer.param_info, group_file_path)
+        param_groups = [
+            {**group, "params": group_info["params"]}
+            for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+        ]
+        save_param_groups({"param_groups": param_groups}, group_file_path)

         final_index_file.write_index_file(final_index_file_path)
         rmtree(tmp_index_file_folder)
@@ -718,7 +726,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         if self.pp_size == 1:
             # When pipeline is not used, let master rank directly save the collected state_dict.
-            state_dict = {"param_groups": optimizer.param_info["param_groups"], "state": local_states}
+            param_groups = [
+                {**group, "params": group_info["params"]}
+                for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+            ]
+            state_dict = {"param_groups": param_groups, "state": local_states}
             if self.coordinator.is_master():
                 save_state_dict(state_dict, checkpoint, use_safetensors=False)
         else:
@@ -729,7 +741,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             # Only the master rank does the saving.
             if self.coordinator.is_master():
-                state_dict = {"param_groups": optimizer.param_info["param_groups"], "state": dict()}
+                param_groups = [
+                    {**group, "params": group_info["params"]}
+                    for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+                ]
+                state_dict = {"param_groups": param_groups, "state": dict()}
                 for _states in states_list:
                     state_dict["state"].update(_states)
                 save_state_dict(state_dict, checkpoint, use_safetensors=False)
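
For the pipeline branch above: each stage contributes a shard of the optimizer state, and the master merges the shards under the refreshed param_groups. A hedged toy sketch of that assembly (plain dicts, no distributed code; states_list is fabricated for illustration):

    param_groups = [{"lr": 0.1, "params": [0, 1, 2]}]
    # One partial state dict per pipeline stage, keyed by param index.
    states_list = [{0: {"step": 10}}, {1: {"step": 10}, 2: {"step": 10}}]

    state_dict = {"param_groups": param_groups, "state": dict()}
    for _states in states_list:
        state_dict["state"].update(_states)  # merge per-stage shards

    # state_dict now follows torch.optim's state_dict layout and can be saved.
    assert sorted(state_dict["state"]) == [0, 1, 2]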

colossalai/zero/gemini/gemini_optimizer.py (5 lines changed)

@@ -621,7 +621,10 @@ class GeminiOptimizer(OptimizerWrapper):
         Return the param_groups in Pytorch format when saving to checkpoint.
         """
-        param_groups = copy.deepcopy(self.param_groups_backup)
+        param_groups = [
+            {**group, "params": group_info["params"]}
+            for group, group_info in zip(self.optim.param_groups, self.param_groups_backup)
+        ]
         # To be compatible with pytorch checkpointing,
         # store extra hyperparameters used by pytorch Adam optimizer.
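
As the docstring says, the target is PyTorch's own layout, where each group carries its hyperparameters plus integer "params" ids. A quick reference of that layout using a plain torch optimizer (illustrative only, unrelated to Gemini internals):

    import torch

    model = torch.nn.Linear(4, 2)
    optim = torch.optim.Adam(model.parameters(), lr=0.001)

    groups = optim.state_dict()["param_groups"]
    # e.g. [{"lr": 0.001, "betas": (0.9, 0.999), "eps": 1e-08,
    #        "weight_decay": 0, ..., "params": [0, 1]}]
    print(groups[0]["params"])  # [0, 1]: ids into the state dict, not tensors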

extensions/cuda_extension.py (3 lines changed)

@@ -1,7 +1,10 @@
+import os
+import time
 from abc import abstractmethod
+from pathlib import Path
 from typing import List

 from .base_extension import _Extension
 from .cpp_extension import _CppExtension
 from .utils import check_pytorch_version, check_system_pytorch_cuda_match, set_cuda_arch_list
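
This hunk only adds imports, which suggests the module was referencing os, time, and Path without importing them; in Python that surfaces as a NameError only when the offending line actually runs, here presumably while building the CUDA extension. A tiny illustration of that failure mode (hypothetical helper, not the extension's real code):

    def cuda_home():  # hypothetical helper for illustration
        return os.environ.get("CUDA_HOME")  # NameError if os was never imported

    try:
        cuda_home()
    except NameError as e:
        print(e)  # name 'os' is not defined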

tests/test_checkpoint_io/test_gemini_checkpoint_io.py (6 lines changed)

@@ -97,7 +97,7 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard
     new_model = model_fn()
     optimizer = HybridAdam(model.parameters(), lr=0.001)
     model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
-    new_optimizer = HybridAdam(new_model.parameters(), lr=0.001)
+    new_optimizer = HybridAdam(new_model.parameters(), lr=0.01)
     new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)

     data = data_gen_fn()
@@ -109,6 +109,8 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard
     booster.backward(loss, optimizer)
     optimizer.step()
+    for group in optimizer.param_groups:
+        group["lr"] = 0.1

     with shared_tempdir() as tempdir:
         model_ckpt_path = f"{tempdir}/model"
@@ -127,6 +129,8 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard
         check_state_dict_equal(
             optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(only_rank_0=False), False
         )
+        for group in new_optimizer.param_groups:
+            assert group["lr"] == 0.1

         # Check the new model/optimizer can successfully run.
         data = data_gen_fn()
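
The test's trick is to change lr after stepping, so a checkpoint taken afterwards must reflect the live value. A minimal analogue with a plain torch optimizer (a stand-in for the booster save/load path; the lr values mirror the test):

    import torch

    model, clone = torch.nn.Linear(2, 2), torch.nn.Linear(2, 2)
    opt = torch.optim.Adam(model.parameters(), lr=0.001)
    new_opt = torch.optim.Adam(clone.parameters(), lr=0.01)  # deliberately different

    for group in opt.param_groups:
        group["lr"] = 0.1  # simulate a runtime hyperparameter change

    new_opt.load_state_dict(opt.state_dict())  # stand-in for save + load_optimizer
    for group in new_opt.param_groups:
        assert group["lr"] == 0.1  # the mutated lr survives the round trip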

tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py (3 lines changed)

@@ -83,7 +83,8 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf
     optimizer.backward(loss)
     optimizer.step()
+    for group in optimizer.param_groups:
+        group["lr"] = 0.1

     with shared_tempdir() as tempdir:
         model_ckpt_path = f"{tempdir}/model"
         optimizer_ckpt_path = f"{tempdir}/optimizer"
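
The hybrid-parallel test applies the same guard as the Gemini test above: bump lr to 0.1 after optimizer.step() so the subsequent checkpoint comparison fails if saving still reads the stale boost-time hyperparameters.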
