mirror of https://github.com/hpcaitech/ColossalAI
* [checkpoint] refactored the API and added safetensors support
* polish code

Frank Lee committed 2 years ago (committed by GitHub, pull/3442/head)
9 changed files with 579 additions and 280 deletions
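
As background for the safetensors support added here, a minimal sketch of what the safetensors format does at the file level, using only the public `safetensors.torch` API (the toy tensors and file name are illustrative):

import torch
from safetensors.torch import save_file, load_file

# safetensors stores a flat {name: tensor} mapping without pickling
state_dict = {"linear.weight": torch.randn(4, 4), "linear.bias": torch.zeros(4)}

save_file(state_dict, "toy.safetensors")    # writes tensors plus a small JSON header
restored = load_file("toy.safetensors")     # returns {name: tensor}, loaded on CPU by default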
@@ -1,4 +1,5 @@
-from .checkpoint_io_base import CheckpointIO, ShardCheckpointIndexFile
+from .checkpoint_io_base import CheckpointIO
 from .general_checkpoint_io import GeneralCheckpointIO
+from .index_file import CheckpointIndexFile
 
-__all__ = ['CheckpointIO', 'ShardCheckpointIndexFile', 'GeneralCheckpointIO']
+__all__ = ['CheckpointIO', 'CheckpointIndexFile', 'GeneralCheckpointIO']
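
For reference, the refactored public API is then consumed as below (a sketch only; the package path colossalai.checkpoint_io is an assumption based on the relative imports in this hunk, and ShardCheckpointIndexFile is no longer exported):

# hypothetical import path, inferred from the relative imports above
from colossalai.checkpoint_io import CheckpointIO, CheckpointIndexFile, GeneralCheckpointIO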
@@ -0,0 +1,150 @@
import json
from pathlib import Path
from typing import Any, List, Tuple, Union

from .utils import is_dtensor_checkpoint

__all__ = ['CheckpointIndexFile']


class CheckpointIndexFile:
    """
    This class is a data structure that keeps the content of the index.json file for a sharded checkpoint.

    Example:
        >>> index = CheckpointIndexFile.from_file('model.index.json')
        >>> index.append_meta_data('model_type', 'bert')
        >>> index.append_weight_map('bert.embeddings.word_embeddings.weight', 'model_0001-of-0002.bin')
        >>> index.export('new_index.json')
    """

    def __init__(self) -> None:
        self.root_path = None
        self.metadata: dict = dict()
        self.weight_map: dict = dict()

    @staticmethod
    def from_file(index_path: Union[str, Path]) -> "CheckpointIndexFile":
        """
        Create a CheckpointIndexFile object from a json file.

        Args:
            index_path (Union[str, Path]): path to the json file.

        Returns:
            CheckpointIndexFile: CheckpointIndexFile object.
        """
        index = CheckpointIndexFile()
        index.load(index_path)
        return index

    def load(self, json_path: str):
        """
        Load the index file from a json file.

        Args:
            json_path (str): path to the json file.
        """
        # load the json file
        with open(json_path, 'r') as f:
            index = json.load(f)

        # assign attributes if they exist
        if "metadata" in index:
            self.metadata = index["metadata"]
        if "weight_map" in index:
            self.weight_map = index["weight_map"]

        # record the root directory of the index file for resolving relative shard paths
        self.root_path = Path(json_path).absolute().parent

    def export(self, json_path: str):
        """
        Export the index file to a json file.

        Args:
            json_path (str): path to the json file.
        """
        # build the index dict
        index = dict()
        index["metadata"] = self.metadata
        index["weight_map"] = self.weight_map

        # export the index file
        with open(json_path, 'w') as f:
            json.dump(index, f, indent=4)

    def append_weight_map(self, param_name: str, shard_file: str):
        """
        Append a weight map entry to the index file.

        Args:
            param_name (str): name of the parameter.
            shard_file (str): name of the shard file.
        """
        self.weight_map[param_name] = shard_file

    def append_meta_data(self, name: str, val: Any):
        """
        Append a metadata entry to the index file.

        Args:
            name (str): name of the metadata.
            val (Any): value of the metadata.
        """
        self.metadata[name] = val

    def contains_dtensor(self) -> bool:
        """
        Check whether the index file contains any distributed tensor. Distributed tensors are stored as
        `dtensor/module.linear.weight.*.bin` or `dtensor/module.linear.weight.*.safetensors` in the weight map.

        Returns:
            bool: True if the index file contains any distributed tensor, False otherwise.
        """
        for value in self.weight_map.values():
            if value.endswith(".*.bin") or value.endswith(".*.safetensors"):
                return True
        return False

    def get_checkpoint_filenames(self) -> Tuple[List[str], List[str]]:
        """
        Get the unique checkpoint filenames in the weight map, split into regular shard files and
        dtensor shard files.

        Returns:
            Tuple[List[str], List[str]]: a tuple of (checkpoint shard filenames, dtensor shard filenames).
        """
        # read the checkpoint file list from the weight map and get a list of unique file names
        checkpoint_files = sorted(list(set(self.weight_map.values())))

        # get the absolute paths for all checkpoint files
        checkpoint_files = [str(self.root_path.joinpath(f)) for f in checkpoint_files]

        dtensor_list = []
        checkpoint_list = []

        for ckpt_file in checkpoint_files:
            if is_dtensor_checkpoint(ckpt_file):
                dtensor_list.append(ckpt_file)
            else:
                checkpoint_list.append(ckpt_file)

        return checkpoint_list, dtensor_list

    def assert_no_dtensor_checkpoint(self):
        """
        Raise a ValueError if any weight map entry points to a distributed tensor checkpoint.
        """
        for val in self.weight_map.values():
            if is_dtensor_checkpoint(val):
                raise ValueError(f"Checkpoint file {val} contains distributed tensor")

    def get_checkpoint_file(self, param_name: str) -> str:
        """
        Get the checkpoint file name for a parameter.

        Args:
            param_name (str): name of the parameter.

        Returns:
            str: checkpoint file name.
        """
        ckpt_path = self.weight_map[param_name]
        return ckpt_path
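
A short usage sketch of the index file above (the two-shard file names and metadata value are made up for illustration, and the class is assumed to be in scope, e.g. imported from this new module):

# build an index for a checkpoint split into two shard files
index = CheckpointIndexFile()
index.append_meta_data('total_size', 2048)                               # illustrative metadata entry
index.append_weight_map('encoder.layer.0.weight', '00001-of-00002.bin')
index.append_weight_map('decoder.layer.0.weight', '00002-of-00002.bin')
index.export('model.index.json')

# read it back and resolve the shard files
index = CheckpointIndexFile.from_file('model.index.json')
assert not index.contains_dtensor()
ckpt_files, dtensor_files = index.get_checkpoint_filenames()             # absolute paths under root_path
print(index.get_checkpoint_file('encoder.layer.0.weight'))               # -> '00001-of-00002.bin'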
@@ -0,0 +1,278 @@
from pathlib import Path
from typing import List, Optional, Tuple

import torch

# ======================================
# General helper functions
# ======================================


def calculate_tensor_size(tensor: torch.Tensor) -> float:
    """
    Calculate the size of a tensor in MB. Used to compute whether a group of params exceeds the shard size.
    If so, a new shard should be created.

    Args:
        tensor (torch.Tensor): the tensor to calculate the size for.

    Returns:
        float: size of the tensor in MB.
    """
    return tensor.numel() * tensor.element_size() / 1024 / 1024


def is_safetensors_available() -> bool:
    """
    Check whether safetensors is available.

    Returns:
        bool: whether safetensors is available.
    """
    try:
        import safetensors
        return True
    except ImportError:
        return False


def is_dtensor_checkpoint(checkpoint_file_path: str) -> bool:
    """
    Check whether the checkpoint file is a dtensor checkpoint.

    Args:
        checkpoint_file_path (str): path to the checkpoint file.

    Returns:
        bool: whether the checkpoint file is a dtensor checkpoint.
    """
    if checkpoint_file_path.endswith('.*.safetensors') or checkpoint_file_path.endswith('.*.bin'):
        return True
    else:
        return False


def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
    """
    Check whether the checkpoint file is a safetensor checkpoint.

    Args:
        checkpoint_file_path (str): path to the checkpoint file.

    Returns:
        bool: whether the checkpoint file is a safetensor checkpoint.
    """
    if checkpoint_file_path.endswith('.safetensors'):
        return True
    else:
        return False


# ======================================
# Helper functions for saving state dict
# ======================================


def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) -> None:
    """
    Save state dict to checkpoint.

    Args:
        state_dict (dict): state dict.
        checkpoint_file_path (str): path to the checkpoint file.
        use_safetensors (bool): whether to use safetensors to save the checkpoint.
    """
    if use_safetensors:
        assert is_safetensors_available(), "safetensors is not available."
        assert checkpoint_file_path.endswith('.safetensors'), \
            "safetensors only supports the .safetensors suffix for checkpoint files."
        from safetensors.torch import save_file
        save_file(state_dict, checkpoint_file_path)
    else:
        torch.save(state_dict, checkpoint_file_path)


def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFile", use_safetensors: bool) -> None:
    """
    Save a distributed tensor to a checkpoint. This checkpoint will be a dictionary which contains
    only one tensor.

    Args:
        name (str): name of the distributed parameter.
        tensor (Tensor): tensor to be saved.
        index_file (CheckpointIndexFile): index file used to record the weight map entry for this tensor.
        use_safetensors (bool): whether to use safetensors to save the checkpoint.
    """
    root_path = index_file.root_path
    output_root_path = root_path.joinpath('dtensor')

    # create directory
    output_root_path.mkdir(exist_ok=True)

    # save tensor to this directory
    # TODO(YuliangLiu): get index of the tensor shard
    # e.g. index =
    index = 0

    # save tensor to file
    ckpt_file_name = generate_dtensor_file_name(name, index, use_safetensors)
    ckpt_file_path = output_root_path.joinpath(ckpt_file_name)

    # dtensor ckpt file always contains only one tensor
    state_dict = {name: tensor}
    save_state_dict(state_dict, str(ckpt_file_path), use_safetensors)

    # update the weight map
    # * means all shards
    ckpt_file_name_in_weight_map = 'dtensor/' + generate_dtensor_file_name(name, '*', use_safetensors)
    index_file.append_weight_map(name, ckpt_file_name_in_weight_map)


def get_checkpoint_file_suffix(use_safetensors: bool) -> str:
    """
    Get the checkpoint file suffix, including the leading dot.

    Args:
        use_safetensors (bool): whether to use safetensors to save the checkpoint.

    Returns:
        str: checkpoint file suffix.
    """
    if use_safetensors:
        return '.safetensors'
    else:
        return '.bin'


def generate_checkpoint_shard_file_name(index: int,
                                        total_number: int,
                                        use_safetensors: bool,
                                        prefix: str = None) -> str:
    """
    Generate checkpoint shard file name.

    Args:
        index (int): index of the shard.
        total_number (int): total number of shards.
        use_safetensors (bool): whether to use safetensors to save the checkpoint.
        prefix (str): prefix of the shard file name. Default: None.

    Returns:
        str: checkpoint shard file name.
    """
    # the suffix already contains the leading dot, e.g. '.bin' or '.safetensors'
    suffix = get_checkpoint_file_suffix(use_safetensors)

    if prefix is None:
        return f"{index:05d}-of-{total_number:05d}{suffix}"
    else:
        return f"{prefix}-{index:05d}-of-{total_number:05d}{suffix}"


def generate_dtensor_file_name(param_name: str, index: int, use_safetensors: bool) -> str:
    """
    Generate dtensor file name.

    Args:
        param_name (str): name of the distributed parameter.
        index (int): index of the shard.
        use_safetensors (bool): whether to use safetensors to save the checkpoint.

    Returns:
        str: dtensor file name.
    """
    # the suffix already contains the leading dot, so the resulting name is e.g. 'weight.0.bin'
    suffix = get_checkpoint_file_suffix(use_safetensors)
    return f'{param_name}.{index}{suffix}'


def save_state_dict_as_shard(
    state_dict: dict,
    checkpoint_path: str,
    index: int,
    total_number: int,
    use_safetensors: bool,
    prefix: str = None,
) -> None:
    """
    Save a state dict as a checkpoint shard.

    Args:
        state_dict (dict): state dict.
        checkpoint_path (str): path to the checkpoint directory.
        index (int): index of the shard.
        total_number (int): total number of shards.
        use_safetensors (bool): whether to use safetensors to save the checkpoint.
        prefix (str): prefix of the shard file name. Default: None.
    """
    # generate the shard name
    shard_file_name = generate_checkpoint_shard_file_name(index, total_number, use_safetensors, prefix)
    shard_file_path = Path(checkpoint_path).joinpath(shard_file_name).absolute()

    # save the shard
    save_state_dict(state_dict, str(shard_file_path), use_safetensors)


# ========================================
# Helper functions for loading state dict
# ========================================


def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]:
    """
    Check whether the checkpoint has an index file.

    Args:
        checkpoint_path (str): path to the checkpoint.

    Returns:
        Tuple[bool, Optional[Path]]: a tuple of (has_index_file, index_file_path).
    """
    checkpoint_path = Path(checkpoint_path)
    if checkpoint_path.is_file():
        # check if it is a .index.json file
        if checkpoint_path.name.endswith('.index.json'):
            return True, checkpoint_path
        else:
            return False, None
    elif checkpoint_path.is_dir():
        # check if there is exactly one file ending with .index.json in this directory
        index_files = list(checkpoint_path.glob('*.index.json'))

        # if we found any .index.json file, make sure there is only one
        if len(index_files) > 0:
            assert len(
                index_files
            ) == 1, f'Expected to find one .index.json file in {checkpoint_path}, but found {len(index_files)}'

        if len(index_files) == 1:
            return True, index_files[0]
        else:
            return False, None
    else:
        raise FileNotFoundError(f'Checkpoint path {checkpoint_path} is neither a file nor a directory.')


def load_state_dict(checkpoint_file_path: Path):
    """
    Load state dict from checkpoint.

    Args:
        checkpoint_file_path (Path): path to the checkpoint file.

    Returns:
        dict: state dict.
    """
    # the suffix checks below operate on strings
    checkpoint_file_path = str(checkpoint_file_path)

    assert not is_dtensor_checkpoint(checkpoint_file_path), \
        f'Cannot load state dict from the dtensor checkpoint {checkpoint_file_path}. You should convert the distributed tensors to gathered tensors with our CLI offline.'

    if is_safetensor_checkpoint(checkpoint_file_path):
        assert is_safetensors_available(), \
            f'Cannot load state dict from the safetensor checkpoint {checkpoint_file_path} because safetensors is not available. Please install safetensors first with pip install safetensors.'
        # load with safetensors
        from safetensors import safe_open
        state_dict = {}
        with safe_open(checkpoint_file_path, framework="pt", device="cpu") as f:
            for k in f.keys():
                state_dict[k] = f.get_tensor(k)
        return state_dict

    else:
        # load with torch
        return torch.load(checkpoint_file_path)
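
To tie the helpers together, a sketch that shards a toy model's state dict by size, writes an index, and loads one shard back. The toy model, output directory, and 0.001 MB shard limit are illustrative assumptions, and the packing loop only mirrors the intent of the sharded-save path in this PR rather than reproducing GeneralCheckpointIO; the helpers above and CheckpointIndexFile are assumed to be in scope.

from pathlib import Path

import torch.nn as nn

# toy model and output directory (illustrative only)
model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 4))
ckpt_dir = Path('./toy_ckpt')
ckpt_dir.mkdir(exist_ok=True)

# greedily pack parameters into shards of at most `max_shard_size` MB
max_shard_size = 0.001
shards, current_shard, current_size = [], {}, 0.0
for name, tensor in model.state_dict().items():
    tensor_size = calculate_tensor_size(tensor)
    if current_shard and current_size + tensor_size > max_shard_size:
        shards.append(current_shard)
        current_shard, current_size = {}, 0.0
    current_shard[name] = tensor
    current_size += tensor_size
shards.append(current_shard)

# write each shard and record its parameters in an index file
use_safetensors = is_safetensors_available()
index = CheckpointIndexFile()
for i, shard in enumerate(shards, start=1):
    shard_name = generate_checkpoint_shard_file_name(i, len(shards), use_safetensors)
    save_state_dict_as_shard(shard, str(ckpt_dir), i, len(shards), use_safetensors)
    for param_name in shard:
        index.append_weight_map(param_name, shard_name)
index.export(str(ckpt_dir / 'model.index.json'))

# locate the index and load the first (non-dtensor) shard back
found, index_path = has_index_file(str(ckpt_dir))
assert found
loaded_index = CheckpointIndexFile.from_file(index_path)
shard_files, dtensor_files = loaded_index.get_checkpoint_filenames()
state_dict = load_state_dict(Path(shard_files[0]))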