|
|
|
@ -1,8 +1,8 @@
|
|
|
|
|
import json |
|
|
|
|
from pathlib import Path |
|
|
|
|
from typing import Any, List, Union |
|
|
|
|
import os |
|
|
|
|
import json |
|
|
|
|
from collections import OrderedDict |
|
|
|
|
from pathlib import Path |
|
|
|
|
from typing import Any, Dict, List, Union |
|
|
|
|
|
|
|
|
|
from .utils import is_dtensor_checkpoint |
|
|
|
|
|
|
|
|
@ -22,8 +22,10 @@ class CheckpointIndexFile:
|
|
|
|
|
|
|
|
|
|
def __init__(self, root_path=None) -> None: |
|
|
|
|
self.root_path = root_path |
|
|
|
|
self.metadata: dict = dict() |
|
|
|
|
self.weight_map: dict = dict() |
|
|
|
|
|
|
|
|
|
# use ordered dict to preserve the tensor checkpoint order |
|
|
|
|
self.metadata: Dict = OrderedDict() |
|
|
|
|
self.weight_map: Dict = OrderedDict() |
|
|
|
|
|
|
|
|
|
@staticmethod |
|
|
|
|
def from_file(index_path: Union[str, Path]): |
|
|
|
@ -150,13 +152,13 @@ class CheckpointIndexFile:
|
|
|
|
|
""" |
|
|
|
|
ckpt_path = self.weight_map[param_name] |
|
|
|
|
return ckpt_path |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_all_param_names(self): |
|
|
|
|
""" |
|
|
|
|
Get all the weight keys. |
|
|
|
|
""" |
|
|
|
|
return list(self.weight_map.keys()) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def write_index_file(self, save_index_file): |
|
|
|
|
""" |
|
|
|
|
Write index file. |
|
|
|
@ -164,5 +166,5 @@ class CheckpointIndexFile:
|
|
|
|
|
save_index_file = os.path.join(self.root_path, save_index_file) |
|
|
|
|
index = {"metadata": self.metadata, "weight_map": self.weight_map} |
|
|
|
|
with open(save_index_file, "w", encoding="utf-8") as f: |
|
|
|
|
content = json.dumps(index, indent=2, sort_keys=True) + "\n" |
|
|
|
|
content = json.dumps(index, indent=2) + "\n" |
|
|
|
|
f.write(content) |
|
|
|
|