import json
import os
from collections import OrderedDict
from pathlib import Path
from typing import Any, Dict, List, Tuple, Union

from .utils import is_dtensor_checkpoint

__all__ = ["CheckpointIndexFile"]


class CheckpointIndexFile:
    """
    This class is a data structure that keeps the content of the `index.json` file for a sharded checkpoint.

    Example:
        >>> index = CheckpointIndexFile.from_file('model.index.json')
        >>> index.append_meta_data('model_type', 'bert')
        >>> index.append_weight_map('bert.embeddings.word_embeddings.weight', 'model_0001-of-0002.bin')
        >>> index.export('new_index.json')
    """

    def __init__(self, root_path=None) -> None:
        self.root_path = root_path

        # use an ordered dict to preserve the order in which tensor shards are recorded
        self.metadata: Dict = OrderedDict()
        self.weight_map: Dict = OrderedDict()

    @staticmethod
    def from_file(index_path: Union[str, Path]) -> "CheckpointIndexFile":
        """
        Create a CheckpointIndexFile object from a json file.

        Args:
            index_path (Union[str, Path]): path to the json file.

        Returns:
            CheckpointIndexFile: CheckpointIndexFile object.
        """
        index = CheckpointIndexFile()
        index.load(index_path)
        return index

    def load(self, json_path: str):
        """
        Load the index file from a json file.

        Args:
            json_path (str): path to the json file.
        """
        # load the json file
        with open(json_path, "r") as f:
            index = json.load(f)

        # assign the attributes if they exist in the loaded file
        if "metadata" in index:
            self.metadata = index["metadata"]
        if "weight_map" in index:
            self.weight_map = index["weight_map"]

        # assign the root directory of the index file
        self.root_path = Path(json_path).absolute().parent

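    # Usage sketch (hypothetical path, for illustration only): loading re-populates
    # `metadata` and `weight_map`, and points `root_path` at the directory that
    # contains the index file:
    #
    #   index = CheckpointIndexFile.from_file("/ckpt/model.index.json")
    #   index.root_path  # -> Path("/ckpt")
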
    def export(self, json_path: str):
        """
        Export the index file to a json file.

        Args:
            json_path (str): path to the json file.
        """
        # build the index dict
        index = dict()
        index["metadata"] = self.metadata
        index["weight_map"] = self.weight_map

        # export the index file
        with open(json_path, "w") as f:
            json.dump(index, f, indent=4)

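    # For reference, the exported JSON mirrors the common sharded-checkpoint index
    # layout, e.g. (illustrative values only):
    #
    #   {
    #       "metadata": {"total_size": 1024},
    #       "weight_map": {"linear.weight": "model_0001-of-0002.bin"}
    #   }
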
    def append_weight_map(self, param_name: str, shard_file: str):
        """
        Append a weight map entry to the index file.

        Args:
            param_name (str): name of the parameter.
            shard_file (str): name of the shard file.
        """
        self.weight_map[param_name] = shard_file

    def append_meta_data(self, name: str, val: Any):
        """
        Append a metadata entry to the index file.

        Args:
            name (str): name of the metadata entry.
            val (Any): value of the metadata entry.
        """
        self.metadata[name] = val

    def contains_dtensor(self) -> bool:
        """
        Check if the index file contains any distributed tensor. Distributed tensors are stored as
        `dtensor/module.linear.weight.*.bin` or `dtensor/module.linear.weight.*.safetensors` entries
        in the weight map.

        Returns:
            bool: True if the index file contains any distributed tensor, False otherwise.
        """
        for value in self.weight_map.values():
            if value.endswith(".*.bin") or value.endswith(".*.safetensors"):
                return True
        return False

    def get_checkpoint_filenames(self) -> Tuple[List[str], List[str]]:
        """
        Get the unique checkpoint filenames in the weight map, split into regular shard files
        and distributed-tensor (dtensor) shard files.

        Returns:
            Tuple[List[str], List[str]]: absolute paths of the regular checkpoint shard files
                and of the dtensor checkpoint shard files.
        """
        # collect the unique shard file names referenced by the weight map
        checkpoint_files = sorted(set(self.weight_map.values()))

        # get the absolute paths for all checkpoint files
        checkpoint_files = [str(self.root_path.joinpath(f)) for f in checkpoint_files]

        # split regular shard files from dtensor shard files
        dtensor_list = []
        checkpoint_list = []
        for ckpt_file in checkpoint_files:
            if is_dtensor_checkpoint(ckpt_file):
                dtensor_list.append(ckpt_file)
            else:
                checkpoint_list.append(ckpt_file)

        return checkpoint_list, dtensor_list

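    # Example of the returned pair, assuming `root_path` is `/ckpt` and the weight
    # map references one regular shard and one dtensor shard (hypothetical names):
    #
    #   (["/ckpt/model_0001-of-0002.bin"], ["/ckpt/dtensor/module.linear.weight.*.bin"])
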
    def assert_no_dtensor_checkpoint(self):
        """
        Raise a ValueError if any entry in the weight map refers to a distributed tensor checkpoint file.
        """
        for val in self.weight_map.values():
            if is_dtensor_checkpoint(val):
                raise ValueError(f"Checkpoint file {val} contains distributed tensor")

    def get_checkpoint_file(self, param_name: str) -> str:
        """
        Get the checkpoint file name for a parameter.

        Args:
            param_name (str): name of the parameter.

        Returns:
            str: checkpoint file name.
        """
        return self.weight_map[param_name]

    def get_all_param_names(self) -> List[str]:
        """
        Get all the parameter names recorded in the weight map.
        """
        return list(self.weight_map.keys())

    def get_param_group_filename(self) -> Union[str, None]:
        """
        Get the absolute path of the param_groups file if this is a checkpoint for an optimizer.

        Returns:
            Union[str, None]: absolute path of the param_groups file, or None if the
                metadata has no `param_groups` entry.
        """
        filename = self.metadata.get("param_groups", None)
        if filename:
            return str(self.root_path.joinpath(filename))
        else:
            return None

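    # Usage sketch (hypothetical metadata value, for illustration only):
    #
    #   index.append_meta_data("param_groups", "optimizer_param_groups.bin")
    #   index.get_param_group_filename()  # -> "<root_path>/optimizer_param_groups.bin"
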
    def write_index_file(self, save_index_file):
        """
        Write the index file under `root_path`.

        Args:
            save_index_file (str): name of the index file to write, relative to `root_path`.
        """
        save_index_file = os.path.join(self.root_path, save_index_file)
        index = {"metadata": self.metadata, "weight_map": self.weight_map}
        with open(save_index_file, "w", encoding="utf-8") as f:
            content = json.dumps(index, indent=2) + "\n"
            f.write(content)
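

if __name__ == "__main__":
    # Minimal round-trip sketch, not part of the library API. It assumes this
    # module is imported as part of its package (the `.utils` relative import
    # above requires that); all file and parameter names below are made up.
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        index = CheckpointIndexFile(root_path=Path(tmp_dir))
        index.append_meta_data("total_size", 1024)
        index.append_weight_map("linear.weight", "model_0001-of-0002.bin")
        index.append_weight_map("linear.bias", "model_0002-of-0002.bin")

        index_path = os.path.join(tmp_dir, "model.index.json")
        index.export(index_path)

        reloaded = CheckpointIndexFile.from_file(index_path)
        assert reloaded.get_all_param_names() == ["linear.weight", "linear.bias"]
        assert reloaded.get_checkpoint_file("linear.bias") == "model_0002-of-0002.bin"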