diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 46714fe1c..4a7efc165 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -99,8 +99,11 @@ class GeminiCheckpointIO(GeneralCheckpointIO): save_state_dict(shard, checkpoint_file_path, use_safetensors) index_file.append_meta_data("total_size", total_size) - index_file.write_index_file(save_index_file) - logging.info(f"The model is going to be split to checkpoint shards. " + + # only save the index file on the master rank + if self.coordinator.is_master(): + index_file.write_index_file(save_index_file) + logging.info(f"The model is split into checkpoint shards. " f"You can find where each parameters has been saved in the " f"index located at {save_index_file}.") diff --git a/colossalai/checkpoint_io/index_file.py b/colossalai/checkpoint_io/index_file.py index 334ecbc04..a41cc482e 100644 --- a/colossalai/checkpoint_io/index_file.py +++ b/colossalai/checkpoint_io/index_file.py @@ -1,8 +1,8 @@ import json -from pathlib import Path -from typing import Any, List, Union import os -import json +from collections import OrderedDict +from pathlib import Path +from typing import Any, Dict, List, Union from .utils import is_dtensor_checkpoint @@ -22,8 +22,10 @@ class CheckpointIndexFile: def __init__(self, root_path=None) -> None: self.root_path = root_path - self.metadata: dict = dict() - self.weight_map: dict = dict() + + # use ordered dict to preserve the tensor checkpoint order + self.metadata: Dict = OrderedDict() + self.weight_map: Dict = OrderedDict() @staticmethod def from_file(index_path: Union[str, Path]): @@ -150,13 +152,13 @@ class CheckpointIndexFile: """ ckpt_path = self.weight_map[param_name] return ckpt_path - + def get_all_param_names(self): """ Get all the weight keys. """ return list(self.weight_map.keys()) - + def write_index_file(self, save_index_file): """ Write index file. @@ -164,5 +166,5 @@ class CheckpointIndexFile: save_index_file = os.path.join(self.root_path, save_index_file) index = {"metadata": self.metadata, "weight_map": self.weight_map} with open(save_index_file, "w", encoding="utf-8") as f: - content = json.dumps(index, indent=2, sort_keys=True) + "\n" + content = json.dumps(index, indent=2) + "\n" f.write(content) diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 7e23fdb42..094320c4a 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -716,7 +716,10 @@ class _StateDictSharder: tensor_size = calculate_tensor_size(tensor) ret_block = None ret_block_size = 0 - if self.current_block_size + tensor_size > self.max_shard_size: + + # before we return the current block and create a new block, + # we need to ensure that the current block is not empty + if self.current_block_size + tensor_size > self.max_shard_size and self.current_block_size > 0: ret_block = self.current_block ret_block_size = self.current_block_size self.current_block = OrderedDict()