mirror of https://github.com/hpcaitech/ColossalAI
[gemini] fixed the gemini checkpoint io (#3934)
parent b3ab7fbabf
commit 71fe52769c
@@ -99,8 +99,11 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
             save_state_dict(shard, checkpoint_file_path, use_safetensors)
 
         index_file.append_meta_data("total_size", total_size)
-        index_file.write_index_file(save_index_file)
-        logging.info(f"The model is going to be split to checkpoint shards. "
+
+        # only save the index file on the master rank
+        if self.coordinator.is_master():
+            index_file.write_index_file(save_index_file)
+        logging.info(f"The model is split into checkpoint shards. "
                      f"You can find where each parameters has been saved in the "
                      f"index located at {save_index_file}.")
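The change above makes writing the shared index file a master-rank-only step, so concurrent processes cannot race on the same file. A minimal sketch of that guard pattern (not the ColossalAI implementation; the function name, the rank argument, and the file names are hypothetical):

    import json
    import os


    def save_shard_and_index(shard: dict, index: dict, save_dir: str, rank: int) -> None:
        """Write this rank's shard, but let only rank 0 write the shared index file."""
        os.makedirs(save_dir, exist_ok=True)

        # every rank persists the data it owns
        shard_path = os.path.join(save_dir, f"shard-{rank:05d}.json")
        with open(shard_path, "w", encoding="utf-8") as f:
            json.dump(shard, f)

        # only the master rank writes the single shared index file,
        # so concurrent writers cannot clobber or duplicate it
        if rank == 0:
            index_path = os.path.join(save_dir, "model.index.json")
            with open(index_path, "w", encoding="utf-8") as f:
                json.dump(index, f, indent=2)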
@@ -1,8 +1,8 @@
-import json
-from pathlib import Path
-from typing import Any, List, Union
+import os
+import json
+from collections import OrderedDict
+from pathlib import Path
+from typing import Any, Dict, List, Union
 
 from .utils import is_dtensor_checkpoint
@@ -22,8 +22,10 @@ class CheckpointIndexFile:
 
     def __init__(self, root_path=None) -> None:
         self.root_path = root_path
-        self.metadata: dict = dict()
-        self.weight_map: dict = dict()
+
+        # use ordered dict to preserve the tensor checkpoint order
+        self.metadata: Dict = OrderedDict()
+        self.weight_map: Dict = OrderedDict()
 
     @staticmethod
     def from_file(index_path: Union[str, Path]):
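The metadata and weight_map containers are switched to OrderedDict so the index lists tensors in the order they were appended, matching the checkpoint shard order. A small standalone illustration (tensor and shard names are made up):

    from collections import OrderedDict

    weight_map = OrderedDict()
    weight_map["embedding.weight"] = "shard-00001.bin"
    weight_map["blocks.0.attn.weight"] = "shard-00001.bin"
    weight_map["blocks.0.mlp.weight"] = "shard-00002.bin"

    # iteration follows insertion order, i.e. the order the tensors were registered
    assert list(weight_map.keys()) == [
        "embedding.weight",
        "blocks.0.attn.weight",
        "blocks.0.mlp.weight",
    ]

Plain dicts also preserve insertion order on Python 3.7+, so the OrderedDict mainly documents that this ordering is relied upon.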
@@ -150,13 +152,13 @@ class CheckpointIndexFile:
         """
         ckpt_path = self.weight_map[param_name]
         return ckpt_path
 
     def get_all_param_names(self):
         """
         Get all the weight keys.
         """
         return list(self.weight_map.keys())
 
     def write_index_file(self, save_index_file):
         """
         Write index file.
@@ -164,5 +166,5 @@
         save_index_file = os.path.join(self.root_path, save_index_file)
         index = {"metadata": self.metadata, "weight_map": self.weight_map}
         with open(save_index_file, "w", encoding="utf-8") as f:
-            content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+            content = json.dumps(index, indent=2) + "\n"
             f.write(content)
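Dropping sort_keys=True is what lets the OrderedDict ordering survive into the written file: with sorting enabled, json.dumps re-orders the weight map alphabetically and discards the insertion order. A quick demonstration with made-up keys:

    import json
    from collections import OrderedDict

    weight_map = OrderedDict([("layers.2.weight", "s1.bin"), ("layers.10.weight", "s2.bin")])

    print(json.dumps(weight_map, indent=2, sort_keys=True))  # "layers.10.weight" is listed first
    print(json.dumps(weight_map, indent=2))                  # insertion order is kept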
@@ -716,7 +716,10 @@ class _StateDictSharder:
         tensor_size = calculate_tensor_size(tensor)
         ret_block = None
         ret_block_size = 0
-        if self.current_block_size + tensor_size > self.max_shard_size:
+
+        # before we return the current block and create a new block,
+        # we need to ensure that the current block is not empty
+        if self.current_block_size + tensor_size > self.max_shard_size and self.current_block_size > 0:
             ret_block = self.current_block
             ret_block_size = self.current_block_size
             self.current_block = OrderedDict()
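The added current_block_size > 0 condition stops the sharder from flushing an empty block when a single tensor is already larger than max_shard_size; the oversized tensor then simply becomes its own shard. A simplified sketch of that append logic (hypothetical class, with integer sizes standing in for tensors):

    from collections import OrderedDict
    from typing import Optional, Tuple


    class TinySharder:
        def __init__(self, max_shard_size: int) -> None:
            self.max_shard_size = max_shard_size
            self.current_block: "OrderedDict[str, int]" = OrderedDict()
            self.current_block_size = 0

        def append(self, name: str, tensor_size: int) -> Tuple[Optional[OrderedDict], int]:
            ret_block, ret_block_size = None, 0
            # flush the current block only if it already holds something
            if self.current_block_size + tensor_size > self.max_shard_size and self.current_block_size > 0:
                ret_block, ret_block_size = self.current_block, self.current_block_size
                self.current_block, self.current_block_size = OrderedDict(), 0
            self.current_block[name] = tensor_size
            self.current_block_size += tensor_size
            return ret_block, ret_block_size


    sharder = TinySharder(max_shard_size=50)
    print(sharder.append("huge.weight", 100))   # (None, 0): no empty shard is emitted
    print(sharder.append("small.weight", 10))   # flushes the 100-unit block first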