mirror of https://github.com/hpcaitech/ColossalAI
[gemini] fixed the gemini checkpoint io (#3934)
parent b3ab7fbabf
commit 71fe52769c
@@ -99,8 +99,11 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
             save_state_dict(shard, checkpoint_file_path, use_safetensors)
 
         index_file.append_meta_data("total_size", total_size)
-        index_file.write_index_file(save_index_file)
-        logging.info(f"The model is going to be split to checkpoint shards. "
+
+        # only save the index file on the master rank
+        if self.coordinator.is_master():
+            index_file.write_index_file(save_index_file)
+        logging.info(f"The model is split into checkpoint shards. "
                      f"You can find where each parameters has been saved in the "
                      f"index located at {save_index_file}.")
 
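The guard above is the usual write-shared-metadata-once pattern: every rank builds the same index in memory, but only the master rank touches the file, so processes never race on the same path. A minimal standalone sketch, assuming a plain torch.distributed rank check in place of ColossalAI's coordinator (only write_index_file comes from the diff; the function name is illustrative):

import logging
import torch.distributed as dist

def write_index_once(index_file, save_index_file):
    # every rank holds the same in-memory index, but only one process writes
    # it, so ranks never race on (or truncate) the same file
    is_master = (not dist.is_initialized()) or dist.get_rank() == 0
    if is_master:
        index_file.write_index_file(save_index_file)
    logging.info(f"The checkpoint index was written to {save_index_file}.")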
@@ -1,8 +1,8 @@
 import json
-from pathlib import Path
-from typing import Any, List, Union
 import os
-import json
+from collections import OrderedDict
+from pathlib import Path
+from typing import Any, Dict, List, Union
 
 from .utils import is_dtensor_checkpoint
 
@@ -22,8 +22,10 @@ class CheckpointIndexFile:
 
     def __init__(self, root_path=None) -> None:
         self.root_path = root_path
-        self.metadata: dict = dict()
-        self.weight_map: dict = dict()
+
+        # use ordered dict to preserve the tensor checkpoint order
+        self.metadata: Dict = OrderedDict()
+        self.weight_map: Dict = OrderedDict()
 
     @staticmethod
     def from_file(index_path: Union[str, Path]):
@@ -150,13 +152,13 @@ class CheckpointIndexFile:
         """
         ckpt_path = self.weight_map[param_name]
         return ckpt_path
 
     def get_all_param_names(self):
         """
         Get all the weight keys.
         """
         return list(self.weight_map.keys())
 
     def write_index_file(self, save_index_file):
         """
         Write index file.
@@ -164,5 +166,5 @@ class CheckpointIndexFile:
         save_index_file = os.path.join(self.root_path, save_index_file)
         index = {"metadata": self.metadata, "weight_map": self.weight_map}
         with open(save_index_file, "w", encoding="utf-8") as f:
-            content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+            content = json.dumps(index, indent=2) + "\n"
             f.write(content)
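Taken together with the OrderedDict change above, dropping sort_keys=True is what actually keeps the on-disk index in tensor-append order: json.dumps only preserves the dict's insertion order when it is not asked to sort. A standalone illustration, with made-up parameter and shard file names:

import json
from collections import OrderedDict

weight_map = OrderedDict()
weight_map["transformer.wte.weight"] = "pytorch_model-00001.bin"  # appended first
weight_map["lm_head.weight"] = "pytorch_model-00002.bin"          # appended second

index = {"metadata": {"total_size": 123}, "weight_map": weight_map}
print(json.dumps(index, indent=2))                  # keeps the append order
print(json.dumps(index, indent=2, sort_keys=True))  # would reorder keys alphabetically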
@@ -716,7 +716,10 @@ class _StateDictSharder:
         tensor_size = calculate_tensor_size(tensor)
         ret_block = None
         ret_block_size = 0
-        if self.current_block_size + tensor_size > self.max_shard_size:
+
+        # before we return the current block and create a new block,
+        # we need to ensure that the current block is not empty
+        if self.current_block_size + tensor_size > self.max_shard_size and self.current_block_size > 0:
             ret_block = self.current_block
             ret_block_size = self.current_block_size
             self.current_block = OrderedDict()
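The extra current_block_size > 0 condition stops the sharder from emitting an empty shard when a single tensor is larger than max_shard_size. A simplified, self-contained sketch of the append logic after the fix (TinySharder and the integer sizes are illustrative stand-ins for the real _StateDictSharder and calculate_tensor_size):

from collections import OrderedDict

class TinySharder:
    def __init__(self, max_shard_size: int) -> None:
        self.max_shard_size = max_shard_size
        self.current_block = OrderedDict()
        self.current_block_size = 0

    def append(self, name: str, tensor_size: int):
        ret_block, ret_block_size = None, 0
        # only close the current block if it already holds something;
        # otherwise an oversized tensor would flush an empty shard first
        if self.current_block_size + tensor_size > self.max_shard_size and self.current_block_size > 0:
            ret_block, ret_block_size = self.current_block, self.current_block_size
            self.current_block = OrderedDict()
            self.current_block_size = 0
        self.current_block[name] = tensor_size
        self.current_block_size += tensor_size
        return ret_block, ret_block_size

# with the old condition, appending a 10-unit tensor to an empty sharder with
# max_shard_size=4 would have returned an empty OrderedDict as a shard
sharder = TinySharder(max_shard_size=4)
print(sharder.append("large.weight", 10))  # (None, 0): nothing flushed yet
print(sharder.append("small.bias", 1))     # flushes the 10-unit block first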