mirror of https://github.com/hpcaitech/ColossalAI

[hotfix] fix the bug of repeatedly storing param group (#4951)

parent  be82b5d4ca
commit  c040d70aa0
@@ -150,9 +150,10 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         # Preparing file paths and index file.
         states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
         index_file = CheckpointIndexFile(checkpoint)
+        index_file.append_meta_data("param_groups", param_group_file)

         # Store the information of param groups to param_group_file.
-        index_file.append_meta_data("param_groups", param_group_file)
-        group_file_path = os.path.join(checkpoint, param_group_file)
-        param_groups = optimizer.get_param_groups_for_saving()
-        torch.save(param_groups, group_file_path)
+        if self.coordinator.is_master():
+            group_file_path = os.path.join(checkpoint, param_group_file)
+            param_groups = optimizer.get_param_groups_for_saving()
+            torch.save(param_groups, group_file_path)
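The change above is the core of the fix: before it, every rank in the process group executed the torch.save for the param-group file, so the same metadata was stored repeatedly; now only the master rank writes it. As a minimal standalone sketch of that rank-0-only write pattern (using plain torch.distributed instead of ColossalAI's coordinator, with a hypothetical save_param_groups_once helper and a placeholder filename):

import os

import torch
import torch.distributed as dist


def save_param_groups_once(optimizer: torch.optim.Optimizer, checkpoint_dir: str,
                           filename: str = "optim_param_groups.bin") -> None:
    # Every rank can compute the target path, but only the master rank writes it,
    # mirroring the `if self.coordinator.is_master():` guard added above.
    group_file_path = os.path.join(checkpoint_dir, filename)
    is_master = (not dist.is_initialized()) or dist.get_rank() == 0
    if is_master:
        # Param groups (lr, betas, weight decay, parameter ids) live under the
        # "param_groups" key of a standard PyTorch optimizer state dict.
        torch.save(optimizer.state_dict()["param_groups"], group_file_path)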
@@ -161,13 +162,12 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         state_dict_shard = optimizer.state_shard(prefix=prefix, max_shard_size=size_per_shard, only_rank_0=True)

         # Save shards of optimizer states.
-        is_master = self.coordinator.is_master()
         total_size = save_state_dict_shards(
             sharded_state_dict=state_dict_shard,
             checkpoint=checkpoint,
             index_file=index_file,
             base_filename=states_name,
-            is_master=is_master,
+            is_master=self.coordinator.is_master(),
             use_safetensors=False,
         )
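Passing self.coordinator.is_master() directly keeps behavior identical to the removed is_master local; the helper still receives the same flag, and only the master rank actually writes shard files. The following is a rough mental model of a writer with this calling convention, not ColossalAI's actual save_state_dict_shards implementation, assuming the shard iterator yields (shard, shard_size) pairs and ignoring the index-file bookkeeping:

import os
from typing import Dict, Iterable, Tuple

import torch


def save_shards_sketch(sharded_state_dict: Iterable[Tuple[Dict, int]],
                       checkpoint: str, base_filename: str, is_master: bool) -> int:
    # Every rank walks the shard iterator (sizes are needed for the total),
    # but only the master rank serializes anything to disk.
    total_size = 0
    for idx, (shard, shard_size) in enumerate(sharded_state_dict):
        total_size += shard_size
        if is_master:
            # Placeholder naming scheme; the real helper derives shard names
            # from base_filename and records them in the index file.
            torch.save(shard, os.path.join(checkpoint, f"{base_filename}.{idx:05d}"))
    return total_size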
@@ -119,9 +119,10 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         # Preparing file paths and index file.
         states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
         index_file = CheckpointIndexFile(checkpoint)
+        index_file.append_meta_data("param_groups", param_group_file)

         # Store the information of param groups to param_group_file.
-        index_file.append_meta_data("param_groups", param_group_file)
-        group_file_path = os.path.join(checkpoint, param_group_file)
-        save_param_groups(state_dict, group_file_path)
+        if self.coordinator.is_master():
+            group_file_path = os.path.join(checkpoint, param_group_file)
+            save_param_groups(state_dict, group_file_path)
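The same master-only guard is applied here for the low-level ZeRO plugin, with save_param_groups(state_dict, group_file_path) doing the serialization instead of a direct torch.save. End to end, this code path is reached through the booster checkpoint API; a hedged usage sketch follows (launched with torchrun; exact plugin and launch arguments may differ across ColossalAI versions):

import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

colossalai.launch_from_torch(config={})

model = torch.nn.Linear(32, 32).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# GeminiPlugin() would exercise the first two hunks instead.
booster = Booster(plugin=LowLevelZeroPlugin())
model, optimizer, *_ = booster.boost(model, optimizer)

# Sharded save: optimizer states are split across shard files, while the
# param-group file is written once, by the master rank only, after this fix.
booster.save_optimizer(optimizer, "ckpt/optimizer", shard=True, size_per_shard=1024)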