# Mirror of https://github.com/hpcaitech/ColossalAI
# (Repository page notes: you cannot select more than 25 topics; topics must start with a letter or
# number, can include dashes ('-'), and can be up to 35 characters long.)
# Original file: 375 lines, 13 KiB.
import json
|
|
from abc import ABC, abstractmethod
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.optim import Optimizer
|
|
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
|
|
|
# Names exported by `from <this module> import *`.
__all__ = [
    "CheckpointIO",
    "ShardCheckpointIndexFile",
]
|
|
|
|
|
|
class CheckpointIO(ABC):
    """
    CheckpointIO is the base class for all checkpoint IO classes. It defines the interface for checkpoint IO.

    Subclasses must implement ``load_model``, ``save_model``, ``load_optimizer`` and ``save_optimizer``.
    Saving/loading lr schedulers and the helpers for reading/writing (sharded) state dicts are provided here.

    Examples:
        >>> from colossalai.checkpoint_io import GeneralCheckpointIO
        >>> checkpoint_io = GeneralCheckpointIO()
        >>>
        >>> # load model from checkpoint
        >>> model = checkpoint_io.load_model(model, 'model.pt')
        >>>
        >>> # save model to checkpoint
        >>> checkpoint_io.save_model(model, 'model.pt')
        >>>
        >>> # save model to sharded checkpoints
        >>> checkpoint_io.save_model(model, './checkpoints/', shard=True)
        >>>
        >>> # load model from sharded checkpoints
        >>> model = checkpoint_io.load_model(model, './checkpoints/')
        >>>
        >>> # load optimizer from checkpoint
        >>> optimizer = checkpoint_io.load_optimizer(optimizer, 'optimizer.pt')
        >>>
        >>> # save optimizer to checkpoint
        >>> checkpoint_io.save_optimizer(optimizer, 'optimizer.pt')
    """

    # ======================================
    # Abstract methods for implementation
    # ======================================

    @abstractmethod
    def load_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
        """
        Load model from checkpoint.

        Args:
            model (nn.Module): model to be loaded.
            checkpoint (str): checkpoint path. This value is made compatible with the model checkpoints in the
                mainstream model zoos such as Hugging Face and TIMM. The checkpoint path can be:
                1. a file path, e.g. 'model.pt'
                2. a path to a json file which defines the index to the sharded checkpoint
                3. a path to a folder containing a unique .index.json file for sharded checkpoint
            strict (bool): whether to strictly enforce that the parameter names in the checkpoint match
                the keys returned by this module's ``state_dict()``.
        """
        pass

    @abstractmethod
    def save_model(self,
                   model: nn.Module,
                   checkpoint: str,
                   prefix: str = None,
                   shard: bool = False,
                   size_per_shard: int = 1024):
        """
        Save model to checkpoint.

        Examples:
            >>> from colossalai.checkpoint_io import GeneralCheckpointIO
            >>> checkpoint_io = GeneralCheckpointIO()
            >>>
            >>> # save model to a single file
            >>> checkpoint_io.save_model(model, 'model.pt')
            >>>
            >>> # save model to a sharded checkpoint
            >>> checkpoint_io.save_model(model, './checkpoints/', shard=True)

        Args:
            model (nn.Module): model to be saved.
            checkpoint (str): checkpoint path. The checkpoint path can be:
                1. a file path, e.g. 'model.pt'
                2. a directory path to save the sharded checkpoint, e.g. './checkpoints/' when shard = True.
            prefix (str): optional prefix for the shard file names. Default: None.
            shard (bool): whether to shard the checkpoint. Default: False. If set to True, the checkpoint will be
                sharded into multiple files. The model shards will be specified by a `model.index.json` file.
                When shard = True, please ensure that the checkpoint path is a directory path instead of a file path.
            size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set
                to True.
        """
        pass

    @abstractmethod
    def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
        """
        Load optimizer from checkpoint.

        Args:
            optimizer (Optimizer): optimizer to be loaded.
            checkpoint (str): checkpoint path. The checkpoint path can be:
                1. a file path, e.g. 'optimizer.pt'
                2. a path to a json file which defines the index to the sharded checkpoint for the optimizer
                3. a path to a folder containing a unique .index.json file for sharded checkpoint
        """
        pass

    @abstractmethod
    def save_optimizer(self, optimizer: Optimizer, checkpoint: str, shard: bool = False, size_per_shard: int = 1024):
        """
        Save optimizer to checkpoint.

        Args:
            optimizer (Optimizer): optimizer to be saved.
            checkpoint (str): checkpoint path. The checkpoint path can be:
                1. a file path, e.g. 'optimizer.pt'
                2. a directory path to save the sharded checkpoint, e.g. './checkpoints/' when shard = True.
            shard (bool): whether to shard the checkpoint into multiple files. Default: False.
            size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set
                to True.
        """
        pass

    # ============================================
    # methods for loading and saving lr scheduler
    # as this is quite standard, there is no need
    # to make them abstract
    # ============================================

    def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
        """
        Save lr scheduler to checkpoint.

        Args:
            lr_scheduler (LRScheduler): lr scheduler to be saved.
            checkpoint (str): checkpoint path. The checkpoint path can only be a file path.
        """
        torch.save(lr_scheduler.state_dict(), checkpoint)

    def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
        """
        Load lr scheduler from checkpoint.

        Args:
            lr_scheduler (LRScheduler): lr scheduler to be loaded.
            checkpoint (str): the path for a single checkpoint file.
        """
        state_dict = torch.load(checkpoint)
        lr_scheduler.load_state_dict(state_dict)

    # ========================================
    # Helper functions for loading state dict
    # ========================================

    def get_sharded_checkpoint_index_file(self, checkpoint_path: Path):
        """
        Get the index file path for a sharded checkpoint.

        Args:
            checkpoint_path (Path): path to the checkpoint (a ``str`` is also accepted). It must be either an
                ``.index.json`` file or a directory containing exactly one ``.index.json`` file.

        Returns:
            Path: path to the index file.

        Raises:
            ValueError: if the path is a file not ending with ``.index.json``, a directory that does not contain
                exactly one ``.index.json`` file, or does not exist.
        """
        checkpoint_path = Path(checkpoint_path)
        if checkpoint_path.is_file():
            # a sharded checkpoint can be referenced directly through its index file
            if checkpoint_path.name.endswith('.index.json'):
                return checkpoint_path
            raise ValueError(f'Invalid checkpoint path: {checkpoint_path}. ')
        if checkpoint_path.is_dir():
            # the directory must hold exactly one index file, otherwise the choice would be ambiguous
            index_files = list(checkpoint_path.glob('*.index.json'))
            if len(index_files) == 1:
                return index_files[0]
            raise ValueError(f'Found {len(index_files)} index files in {checkpoint_path}. ')
        # previously this fell through and returned None, which only failed later with an obscure error
        raise ValueError(f'Checkpoint path does not exist: {checkpoint_path}')

    def is_sharded_checkpoint(self, checkpoint_path: Path):
        """
        Check whether the checkpoint is sharded.

        Args:
            checkpoint_path (Path): checkpoint path (a ``str`` is also accepted).

        Returns:
            bool: True if the path is an ``.index.json`` file or a directory containing exactly one
            ``.index.json`` file; False if it is any other file or does not exist.

        Raises:
            ValueError: if the path is a directory that does not contain exactly one ``.index.json`` file.
        """
        checkpoint_path = Path(checkpoint_path)
        if checkpoint_path.is_file():
            return checkpoint_path.name.endswith('.index.json')
        if checkpoint_path.is_dir():
            index_files = list(checkpoint_path.glob('*.index.json'))
            if len(index_files) == 1:
                return True
            raise ValueError(f'Found {len(index_files)} index files in {checkpoint_path}. ')
        # non-existent path: it cannot be a sharded checkpoint (previously an implicit None)
        return False

    def get_checkpoint_shard_filenames(self, index_file_path: Path):
        """
        Get checkpoint shard filenames from a json index file.

        Args:
            index_file_path (Path): path to the json index file (a ``str`` is also accepted). The file either
                contains a ``"weight_map"`` entry mapping parameter names to shard file names, or is itself such
                a mapping.

        Returns:
            list[Path]: absolute paths of the unique shard files, sorted by file name. Shard file names are
            resolved relative to the directory containing the index file.
        """
        index_file_path = Path(index_file_path)
        with open(index_file_path, 'r', encoding='utf-8') as f:
            index = json.load(f)

        # Hugging Face-style index files nest the param -> shard mapping under "weight_map"
        if "weight_map" in index:
            index = index["weight_map"]

        checkpoint_root_path = index_file_path.absolute().parent

        # many params share a shard file: deduplicate, and sort for a deterministic order
        checkpoint_files = sorted(set(index.values()))

        # get the absolute paths for all checkpoint files
        return [checkpoint_root_path.joinpath(f) for f in checkpoint_files]

    def load_safetensors_state_dict(self, *args, **kwargs):
        """
        Load safetensors state dict from checkpoint.

        Raises:
            NotImplementedError: always; safetensors is not supported yet.
        """
        # TODO(FrankLeeeee): support huggingface safetensors
        raise NotImplementedError("This method is not implemented to support safe tensors")

    def load_state_dict(self, checkpoint_file_path: Path):
        """
        Load state dict from checkpoint.

        Args:
            checkpoint_file_path (Path): path to the checkpoint file.

        Returns:
            dict: state dict.
        """
        # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
        return torch.load(str(checkpoint_file_path))

    # ======================================
    # Helper functions for saving state dict
    # ======================================

    def save_safetensors_state_dict(self, *args, **kwargs):
        """
        Save safetensors state dict to checkpoint.

        Raises:
            NotImplementedError: always; safetensors is not supported yet.
        """
        # TODO(FrankLeeeee): support huggingface safetensors
        raise NotImplementedError("This method is not implemented to support safe tensors")

    def generate_checkpoint_shard_file_name(self, index: int, total_number: int, prefix: str = None):
        """
        Generate checkpoint shard file name.

        Args:
            index (int): index of the shard.
            total_number (int): total number of shards.
            prefix (str): prefix of the shard file name. Default: None.

        Returns:
            str: ``"{index}-of-{total_number}.bin"``, or ``"{prefix}-{index}-of-{total_number}.bin"`` when a
            prefix is given.
        """
        if prefix is None:
            return f"{index}-of-{total_number}.bin"
        return f"{prefix}-{index}-of-{total_number}.bin"

    def save_checkpoint(self, state_dict: dict, checkpoint_file_path: Path):
        """
        Save state dict to checkpoint.

        Args:
            state_dict (dict): state dict.
            checkpoint_file_path (Path): path to the checkpoint file.
        """
        torch.save(state_dict, str(checkpoint_file_path))

    def save_state_dict_as_shard(self, state_dict: dict, index: int, total_number: int, prefix: str,
                                 checkpoint_path: Path):
        """
        Save state dict as shard.

        Args:
            state_dict (dict): state dict.
            index (int): index of this shard.
            total_number (int): total number of shards.
            prefix (str): prefix of the shard file name, or None for no prefix.
            checkpoint_path (Path): directory in which the shard file is written (a ``str`` is also accepted).
        """
        # generate the shard name
        shard_file_name = self.generate_checkpoint_shard_file_name(index, total_number, prefix)
        shard_file_path = Path(checkpoint_path).joinpath(shard_file_name)

        # save the shard
        self.save_checkpoint(state_dict, shard_file_path)

    def calculate_param_size(self, param: torch.Tensor):
        """
        Calculate the size of a parameter in MB. Used to compute whether a group of params exceed the shard size.
        If so, a new shard should be created.

        Args:
            param (torch.Tensor): parameter tensor.

        Returns:
            float: number of elements times bytes per element, divided by 1024 * 1024.
        """
        # TODO(FrankLeeeee): check if this tensor is a DTensor, compute its global size if so
        return param.numel() * param.element_size() / 1024 / 1024
|
|
|
|
|
|
class ShardCheckpointIndexFile:
    """
    This class is a data structure to keep the content in the index.json file for sharded checkpoint.

    The file holds two dicts: ``metadata`` (arbitrary key/value information) and ``weight_map``
    (parameter name -> name of the shard file that stores it).

    Example:
        >>> index = ShardCheckpointIndexFile()
        >>> index.load('index.json')
        >>> index.append_meta_data('model_type', 'bert')
        >>> index.append_weight_map('bert.embeddings.word_embeddings.weight', 'bert.embeddings.word_embeddings.weight-0-of-2.bin')
        >>> index.export('index.json')
    """

    def __init__(self) -> None:
        # arbitrary metadata entries written under "metadata"
        self.metadata: dict = dict()
        # parameter name -> shard file name, written under "weight_map"
        self.weight_map: dict = dict()

    def load(self, json_path: str):
        """
        Load the index file from a json file.

        Only the ``"metadata"`` and ``"weight_map"`` entries are read; an entry missing from the file leaves the
        corresponding attribute unchanged.

        Args:
            json_path (str): path to the json file.
        """
        # JSON is UTF-8 by specification; do not depend on the platform default encoding
        with open(json_path, 'r', encoding='utf-8') as f:
            index = json.load(f)

        # assign attributes if exists
        if "metadata" in index:
            self.metadata = index["metadata"]
        if "weight_map" in index:
            self.weight_map = index["weight_map"]

    def export(self, json_path: str):
        """
        Export the index file to a json file.

        Args:
            json_path (str): path to the json file; it is overwritten if it exists.
        """
        # create the index file
        index = dict()
        index["metadata"] = self.metadata
        index["weight_map"] = self.weight_map

        # export the index file
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(index, f, indent=4)

    def append_weight_map(self, param_name: str, shard_file: str):
        """
        Append a weight map entry to the index file.

        Args:
            param_name (str): name of the parameter.
            shard_file (str): name of the shard file.
        """
        self.weight_map[param_name] = shard_file

    def append_meta_data(self, name: str, val: Any):
        """
        Append a metadata entry to the index file.

        Args:
            name (str): name of the metadata.
            val (Any): value of the metadata.
        """
        self.metadata[name] = val
|