mirror of https://github.com/hpcaitech/ColossalAI
134 lines
5.1 KiB
Python
134 lines
5.1 KiB
Python
from pathlib import Path
|
|
|
|
import torch.nn as nn
|
|
from torch.optim import Optimizer
|
|
import logging
|
|
import os
|
|
import json
|
|
import gc
|
|
|
|
from .checkpoint_io_base import CheckpointIO
|
|
from .index_file import CheckpointIndexFile
|
|
from .utils import (
|
|
has_index_file,
|
|
load_state_dict,
|
|
save_state_dict,
|
|
is_safetensors_available,
|
|
shard_checkpoint,
|
|
load_shard_state_dict,
|
|
load_state_dict_into_model
|
|
)
|
|
from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME
|
|
|
|
__all__ = ['GeneralCheckpointIO']
|
|
|
|
|
|
class GeneralCheckpointIO(CheckpointIO):
    """
    General-purpose checkpoint IO.

    Saves and loads model / optimizer state dicts either as a single file
    (unsharded) or, for models only, as several shard files plus a JSON index
    (sharded). Sharded optimizer checkpoints are not implemented, and the
    ``gather_dtensor`` flags are accepted but currently have no effect.
    """

    def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
        """
        Load a single-file checkpoint into ``model``.

        Args:
            model: module that receives the weights.
            checkpoint: location of the checkpoint, forwarded to ``load_state_dict``.
            strict: forwarded to ``model.load_state_dict``.
        """
        state_dict = load_state_dict(checkpoint)
        model.load_state_dict(state_dict, strict=strict)

    def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
        """
        Save the full state dict of ``model`` to a single checkpoint.

        Args:
            model: module whose ``state_dict()`` is saved.
            checkpoint: destination, forwarded to ``save_state_dict``.
            gather_dtensor: accepted for interface compatibility; currently ignored.
            use_safetensors: forwarded to ``save_state_dict``.
        """
        state_dict = model.state_dict()

        # TODO(FrankLeeeee): add support for gather_dtensor
        if gather_dtensor:
            pass

        # save the checkpoint
        save_state_dict(state_dict, checkpoint, use_safetensors)

    def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
        """
        Not supported.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")

    def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
        """
        Load a single-file optimizer checkpoint into ``optimizer``.

        Args:
            optimizer: optimizer that receives the state.
            checkpoint: location of the checkpoint, forwarded to ``load_state_dict``.
        """
        state_dict = load_state_dict(checkpoint)
        optimizer.load_state_dict(state_dict)

    def save_sharded_optimizer(
        self,
        optimizer: Optimizer,
        checkpoint: Path,
        gather_dtensor: bool,
        prefix: str,
        size_per_shard: int,
    ):
        """
        Not supported.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")

    def save_unsharded_optimizer(
        self,
        optimizer: Optimizer,
        checkpoint: Path,
        gather_dtensor: bool,
    ):
        """
        Save ``optimizer.state_dict()`` to a single checkpoint.

        The optimizer state is always written with ``use_safetensors=False``.

        Args:
            optimizer: optimizer whose state is saved.
            checkpoint: destination, forwarded to ``save_state_dict``.
            gather_dtensor: accepted for interface compatibility; currently ignored.
        """
        # TODO(FrankLeeeee): handle distributed tensors
        save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False)

    def save_sharded_model(self, model: nn.Module, checkpoint_path: str, gather_dtensor: bool = False,
                           prefix: str = "", max_shard_size: int = 1024, use_safetensors: bool = False):
        """
        Save ``model`` as multiple shard files plus a JSON index file.

        The layout (shard files and an index mapping) is intended to be
        compatible with Huggingface-style sharded checkpoints.

        Args:
            model: module whose ``state_dict()`` is sharded and saved.
            checkpoint_path: target directory; created (with parents) if missing.
                If it points to an existing file, an error is logged and
                nothing is written.
            gather_dtensor: currently unused by this method.
            prefix: currently unused by this method.
            max_shard_size: forwarded to ``shard_checkpoint``
                (NOTE(review): unit is defined by ``shard_checkpoint`` - confirm).
            use_safetensors: selects the safetensors weight/index file names and
                is forwarded to ``save_state_dict``.
        """
        if os.path.isfile(checkpoint_path):
            logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
            return

        Path(checkpoint_path).mkdir(parents=True, exist_ok=True)

        # shard checkpoint
        state_dict = model.state_dict()
        weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME
        shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name)

        # Save the model, one file per shard
        for shard_file, shard in shards.items():
            checkpoint_file_path = os.path.join(checkpoint_path, shard_file)
            save_state_dict(shard, checkpoint_file_path, use_safetensors)

        # save index file
        index_name = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
        save_index_file = os.path.join(checkpoint_path, index_name)
        with open(save_index_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(index, indent=2, sort_keys=True) + "\n")

        logging.info(
            f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
            f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the "
            f"index located at {save_index_file}."
        )

    def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False,
                           use_safetensors: bool = False):
        """
        Load ``model`` from a sharded checkpoint described by an index file.

        Shards are loaded one at a time and released before the next one is
        read, so only a single shard's state dict is held at once.

        Args:
            model: module that receives the weights.
            checkpoint_index_file: path to the checkpoint index file.
            strict: when True, raise if any parameter listed in the index is
                still missing after all shards have been loaded.
            use_safetensors: load shards as safetensors. It is also switched on
                automatically when the index file name contains ``"safetensors"``.

        Raises:
            ImportError: safetensors loading is requested but the library is
                not available.
            RuntimeError: ``strict`` is True and some keys remain missing.
        """
        # Honour the caller's flag and additionally auto-detect from the file
        # name. Path() makes the name lookup work for str inputs as well.
        index_name = Path(checkpoint_index_file).name
        use_safetensors = use_safetensors or "safetensors" in index_name

        if use_safetensors and not is_safetensors_available():
            raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")

        # read checkpoint index file
        ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
        # NOTE: 'fileanames' (sic) is the helper's actual name in the index-file API.
        checkpoint_files, _ = ckpt_index_file.get_checkpoint_fileanames()
        missing_keys = ckpt_index_file.get_all_param_names()

        for shard_file in checkpoint_files:
            state_dict = load_shard_state_dict(Path(shard_file), use_safetensors)
            load_state_dict_into_model(model, state_dict, missing_keys, strict)
            # free the shard before loading the next one
            del state_dict
            gc.collect()

        if strict and len(missing_keys) > 0:
            # Keep the messages in a list: joining a plain string with "\n\t"
            # would insert the separator between every single character.
            error_msgs = [
                'Missing key(s) in state_dict: {}. '.format(
                    ', '.join('"{}"'.format(k) for k in missing_keys))
            ]
            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                self.__class__.__name__, "\n\t".join(error_msgs)))
|
|
|