@@ -1,7 +1,7 @@
 import copy
-from functools import reduce
 import logging
 import os
+from functools import reduce
 from pathlib import Path
 from shutil import rmtree
 from typing import Dict, Iterator, Optional, OrderedDict, Tuple
@@ -445,7 +445,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
                 # Store param groups.
                 index_file.append_meta_data("param_groups", param_group_file)
                 group_file_path = os.path.join(checkpoint, param_group_file)
-                save_param_groups(optimizer.param_info, group_file_path)
+                param_groups = [
+                    {**group, "params": group_info["params"]}
+                    for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+                ]
+                save_param_groups({"param_groups": param_groups}, group_file_path)
                 # Store index file.
                 index_file.append_meta_data("total_size", total_size)
                 index_file.write_index_file(save_index_file)
@@ -504,7 +508,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
                 # Store param groups.
                 final_index_file.append_meta_data("param_groups", param_group_file)
                 group_file_path = os.path.join(checkpoint, param_group_file)
-                save_param_groups(optimizer.param_info, group_file_path)
+                param_groups = [
+                    {**group, "params": group_info["params"]}
+                    for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+                ]
+                save_param_groups({"param_groups": param_groups}, group_file_path)
 
                 final_index_file.write_index_file(final_index_file_path)
                 rmtree(tmp_index_file_folder)
@@ -718,7 +726,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
 
         if self.pp_size == 1:
             # When pipeline is not used, let master rank directly save the collected state_dict.
-            state_dict = {"param_groups": optimizer.param_info["param_groups"], "state": local_states}
+            param_groups = [
+                {**group, "params": group_info["params"]}
+                for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+            ]
+            state_dict = {"param_groups": param_groups, "state": local_states}
             if self.coordinator.is_master():
                 save_state_dict(state_dict, checkpoint, use_safetensors=False)
         else:
@@ -729,7 +741,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
 
             # Only the master rank do the saving.
             if self.coordinator.is_master():
-                state_dict = {"param_groups": optimizer.param_info["param_groups"], "state": dict()}
+                param_groups = [
+                    {**group, "params": group_info["params"]}
+                    for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+                ]
+                state_dict = {"param_groups": param_groups, "state": dict()}
                 for _states in states_list:
                     state_dict["state"].update(_states)
                 save_state_dict(state_dict, checkpoint, use_safetensors=False)
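
Why this change: optimizer.param_info["param_groups"] is a snapshot taken when the optimizer is wrapped, so hyperparameters mutated during training (for example, a learning rate lowered by a scheduler) stayed frozen at their initial values in the checkpoint. Each edited call site now merges the live entries of optimizer.param_groups with the recorded "params" index lists from param_info, so the current hyperparameter values are saved while the parameter-index mapping stays intact. The sketch below demonstrates the merge on a plain PyTorch optimizer; the param_info dict here is a hand-built stand-in for ColossalAI's recorded structure, not the real object.

    import torch

    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # Stand-in for optimizer.param_info: each group's "params" holds saved
    # parameter indices rather than live tensors (assumed shape, for illustration).
    param_info = {"param_groups": [{"params": [0, 1], "lr": 0.1}]}

    # A scheduler (or manual tuning) changes the live learning rate.
    optimizer.param_groups[0]["lr"] = 0.01

    # Old behavior saved param_info["param_groups"] verbatim: stale lr=0.1.
    # New behavior: copy live hyperparameters, keep only the recorded indices.
    param_groups = [
        {**group, "params": group_info["params"]}
        for group, group_info in zip(optimizer.param_groups, param_info["param_groups"])
    ]

    assert param_groups[0]["lr"] == 0.01        # current value is checkpointed
    assert param_groups[0]["params"] == [0, 1]  # index mapping preserved

As a side effect, {**group, "params": group_info["params"]} replaces the live tensor references in each group with the recorded indices, so no parameter tensors are duplicated into the saved param-group file.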