[api] implemented the checkpoint io module (#3205)

* [api] implemented the checkpoint io module

* polish code

* polish code
Frank Lee 2023-03-23 10:53:17 +08:00 committed by GitHub
parent f8289d4221
commit cd142fbefa
4 changed files with 514 additions and 0 deletions

colossalai/checkpoint_io/__init__.py

@@ -0,0 +1,4 @@
from .checkpoint_io_base import CheckpointIO, ShardCheckpointIndexFile
from .general_checkpoint_io import GeneralCheckpointIO
__all__ = ['CheckpointIO', 'ShardCheckpointIndexFile', 'GeneralCheckpointIO']
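For orientation, the snippet below is a minimal, runnable round trip through the exported API (a sketch mirroring the docstring examples further down; the checkpoint file names are illustrative):

    import torch.nn as nn
    from torch.optim import Adam

    from colossalai.checkpoint_io import GeneralCheckpointIO

    # any nn.Module / Optimizer pair works; a tiny model keeps the example fast
    model = nn.Linear(8, 2)
    optimizer = Adam(model.parameters(), lr=1e-3)

    ckpt_io = GeneralCheckpointIO()

    # save to single-file checkpoints
    ckpt_io.save_model(model, 'model.pt')
    ckpt_io.save_optimizer(optimizer, 'optimizer.pt')

    # load back into fresh objects
    new_model = ckpt_io.load_model(nn.Linear(8, 2), 'model.pt')
    new_optimizer = Adam(new_model.parameters(), lr=1e-3)
    new_optimizer = ckpt_io.load_optimizer(new_optimizer, 'optimizer.pt')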

colossalai/checkpoint_io/checkpoint_io_base.py

@@ -0,0 +1,374 @@
import json
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Optional
import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
__all__ = ['CheckpointIO', 'ShardCheckpointIndexFile']
class CheckpointIO(ABC):
"""
CheckpointIO is the base class for all checkpoint IO classes. It defines the interface for checkpoint IO.
Examples:
>>> from colossalai.checkpoint_io import GeneralCheckpointIO
>>> checkpoint_io = GeneralCheckpointIO()
>>>
>>> # load model from checkpoint
>>> model = checkpoint_io.load_model(model, 'model.pt')
>>>
>>> # save model to checkpoint
>>> checkpoint_io.save_model(model, 'model.pt')
>>>
>>> # save model to sharded checkpoints
>>> checkpoint_io.save_model(model, './checkpoints/', shard=True)
>>>
>>> # load model from sharded checkpoints
>>> model = checkpoint_io.load_model(model, './checkpoints/')
>>>
>>> # load optimizer from checkpoint
>>> optimizer = checkpoint_io.load_optimizer(optimizer, 'optimizer.pt')
>>>
>>> # save optimizer to checkpoint
>>> checkpoint_io.save_optimizer(optimizer, 'optimizer.pt')
"""
# ======================================
# Abstract methods for implementation
# ======================================
@abstractmethod
def load_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
"""
Load model from checkpoint.
Args:
model (nn.Module): model to be loaded.
checkpoint (str): checkpoint path. This value is made compatible with the model checkpoints in
mainstream model zoos such as Hugging Face and TIMM. The checkpoint path can be:
1. a file path, e.g. 'model.pt'
2. a path to a json file which defines the index to the sharded checkpoint
3. a path to a folder containing a unique .index.json file for the sharded checkpoint
strict (bool): whether to strictly enforce that the parameter names in
the checkpoint match the keys returned by this module's state_dict.
"""
pass
@abstractmethod
def save_model(self,
model: nn.Module,
checkpoint: str,
prefix: Optional[str] = None,
shard: bool = False,
size_per_shard: int = 1024):
"""
Save model to checkpoint.
Examples:
>>> from colossalai.checkpoint_io import GeneralCheckpointIO
>>> checkpoint_io = GeneralCheckpointIO()
>>>
>>> # save model to a single file
>>> checkpoint_io.save_model(model, 'model.pt')
>>>
>>> # save model to a sharded checkpoint
>>> checkpoint_io.save_model(model, './checkpoints/', shard=True)
Args:
model (nn.Module): model to be saved.
checkpoint (str): checkpoint path. The checkpoint path can be:
1. a file path, e.g. 'model.pt'
2. a directory path to save the sharded checkpoint, e.g. './checkpoints/' when shard = True.
prefix (str): prefix for the shard file names when shard = True. Default: None.
shard (bool): whether to shard the checkpoint. Default: False. If set to True, the checkpoint will be sharded into
multiple files. The model shards will be specified by a `model.index.json` file. When shard = True, please ensure
that the checkpoint path is a directory path instead of a file path.
size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True.
"""
pass
@abstractmethod
def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
"""
Load optimizer from checkpoint.
Args:
optimizer (Optimizer): optimizer to be loaded.
checkpoint (str): checkpoint path. The checkpoint path can be:
1. a file path, e.g. 'optimizer.pt'
2. a path to a json file which defines the index to the sharded checkpoint for the optimizer
3. a path to a folder containing a unique .index.json file for the sharded checkpoint
"""
pass
@abstractmethod
def save_optimizer(self, optimizer: Optimizer, checkpoint: str, shard: bool = False, size_per_shard: int = 1024):
"""
Save optimizer to checkpoint.
Args:
optimizer (Optimizer): optimizer to be saved.
checkpoint (str): checkpoint path. The checkpoint path can be:
1. a file path, e.g. 'optimizer.pt'
2. a path to a json file which defines the index to the sharded checkpoint for the optimizer
3. a path to a folder containing a unique .index.json file for the sharded checkpoint
"""
pass
# ============================================
# methods for loading and saving lr scheduler
# as this is quite standard, there is no need
# to make them abstract
# ============================================
def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
"""
Save lr scheduler to checkpoint.
Args:
lr_scheduler (LRScheduler): lr scheduler to be saved.
checkpoint: checkpoint path. The checkpoint path can only be a file path.
"""
torch.save(lr_scheduler.state_dict(), checkpoint)
def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
"""
Load lr scheduler from checkpoint.
Args:
lr_scheduler (LRScheduler): lr scheduler to be loaded.
checkpoint (str): the path for a single checkpoint file.
"""
state_dict = torch.load(checkpoint)
lr_scheduler.load_state_dict(state_dict)
# ========================================
# Helper functions for loading state dict
# ========================================
def get_sharded_checkpoint_index_file(self, checkpoint_path: Path):
"""
Get the index file path for a sharded checkpoint.
Args:
checkpoint_path (Path): path to the checkpoint.
Returns:
Path: path to the index file.
"""
if checkpoint_path.is_file():
# check if it is .index.json
if checkpoint_path.name.endswith('.index.json'):
return checkpoint_path
else:
raise ValueError(f'Invalid checkpoint path: {checkpoint_path}, expected a file ending with .index.json.')
elif checkpoint_path.is_dir():
# check that there is exactly one file ending with .index.json in this directory
index_files = list(checkpoint_path.glob('*.index.json'))
if len(index_files) == 1:
return index_files[0]
else:
raise ValueError(f'Found {len(index_files)} index files in {checkpoint_path}, expected exactly 1.')
def is_sharded_checkpoint(self, checkpoint_path: Path):
"""
Check whether the checkpoint is sharded.
Args:
checkpoint_path (Path): path to the checkpoint.
Returns:
bool: whether the checkpoint is sharded.
"""
if checkpoint_path.is_file():
# a single file is a sharded checkpoint only if it is the .index.json index file
return checkpoint_path.name.endswith('.index.json')
elif checkpoint_path.is_dir():
# check that there is exactly one file ending with .index.json in this directory
index_files = list(checkpoint_path.glob('*.index.json'))
if len(index_files) == 1:
return True
else:
raise ValueError(f'Found {len(index_files)} index files in {checkpoint_path}, expected exactly 1.')
def get_checkpoint_shard_filenames(self, index_file_path: Path):
"""
Get the checkpoint shard file paths from an index json file.
Args:
index_file_path (Path): path to the index json file.
Returns:
list: absolute paths to the checkpoint shard files.
"""
with open(str(index_file_path), 'r') as f:
index = json.load(f)
if "weight_map" in index:
index = index["weight_map"]
checkpoint_root_path = index_file_path.absolute().parent
# read the checkpoint file list from the json file and get a list of unique file names
checkpoint_files = sorted(set(index.values()))
# resolve the absolute paths for all checkpoint shard files
checkpoint_files = [checkpoint_root_path.joinpath(f) for f in checkpoint_files]
return checkpoint_files
def load_safetensors_state_dict(self, *args, **kwargs):
"""
Load safetensors state dict from checkpoint.
"""
# TODO(FrankLeeeee): support huggingface safetensors
raise NotImplementedError("Loading safetensors checkpoints is not supported yet")
def load_state_dict(self, checkpoint_file_path: Path):
"""
Load state dict from checkpoint.
Args:
checkpoint_file_path (Path): path to the checkpoint file.
Returns:
dict: state dict.
"""
return torch.load(str(checkpoint_file_path))
# ======================================
# Helper functions for saving state dict
# ======================================
def save_safetensors_state_dict(self, *args, **kwargs):
"""
Save safetensors state dict to checkpoint.
"""
# TODO(FrankLeeeee): support huggingface safetensors
raise NotImplementedError("Saving safetensors checkpoints is not supported yet")
def generate_checkpoint_shard_file_name(self, index: int, total_number: int, prefix: Optional[str] = None):
"""
Generate checkpoint shard file name.
Args:
index (int): index of the shard.
total_number (int): total number of shards.
prefix (str): prefix of the shard file name. Default: None.
Returns:
str: the shard file name, e.g. '1-of-4.bin', or '<prefix>-1-of-4.bin' if a prefix is given.
"""
if prefix is None:
return f"{index}-of-{total_number}.bin"
else:
return f"{prefix}-{index}-of-{total_number}.bin"
def save_checkpoint(self, state_dict: dict, checkpoint_file_path: Path):
"""
Save state dict to checkpoint.
Args:
state_dict (dict): state dict.
checkpoint_file_path (Path): path to the checkpoint file.
"""
torch.save(state_dict, str(checkpoint_file_path))
def save_state_dict_as_shard(self, state_dict: dict, index: int, total_number: int, prefix: str,
checkpoint_path: Path):
"""
Save state dict as shard.
Args:
state_dict (dict): state dict.
index (int): index of the shard.
total_number (int): total number of shards.
prefix (str): prefix of the shard file name.
checkpoint_path (Path): path to the directory in which the shard file will be saved.
"""
# generate the shard name
shard_file_name = self.generate_checkpoint_shard_file_name(index, total_number, prefix)
shard_file_path = checkpoint_path.joinpath(shard_file_name)
# save the shard
self.save_checkpoint(state_dict, shard_file_path)
def calculate_param_size(self, param: torch.Tensor):
"""
Calculate the size of a parameter in MB. Used to compute whether a group of params exceeds the shard size.
If so, a new shard should be created.
Args:
param (torch.Tensor): parameter tensor.
"""
# TODO(FrankLeeeee): check if this tensor is a DTensor, compute its global size if so
return param.numel() * param.element_size() / 1024 / 1024
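# Worked example for calculate_param_size: a float32 tensor of shape (1000, 1000)
# occupies 1000 * 1000 * 4 bytes = 4,000,000 bytes ~ 3.81 MB, so roughly 268 such
# tensors fit into the default 1024 MB shard before a new shard must be started.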
class ShardCheckpointIndexFile:
"""
This class is a data structure that holds the content of the index.json file for a sharded checkpoint.
Example:
>>> index = ShardCheckpointIndexFile()
>>> index.load('index.json')
>>> index.append_meta_data('model_type', 'bert')
>>> index.append_weight_map('bert.embeddings.word_embeddings.weight', 'bert.embeddings.word_embeddings.weight-0-of-2.bin')
>>> index.export('index.json')
"""
def __init__(self) -> None:
self.metadata: dict = dict()
self.weight_map: dict = dict()
def load(self, json_path: str):
"""
Load the index file from a json file.
Args:
json_path (str): path to the json file.
"""
# load the json file
with open(json_path, 'r') as f:
index = json.load(f)
# assign the attributes if they exist
if "metadata" in index:
self.metadata = index["metadata"]
if "weight_map" in index:
self.weight_map = index["weight_map"]
def export(self, json_path: str):
"""
Export the index file to a json file.
Args:
json_path (str): path to the json file.
"""
# create the index file
index = dict()
index["metadata"] = self.metadata
index["weight_map"] = self.weight_map
# export the index file
with open(json_path, 'w') as f:
json.dump(index, f, indent=4)
def append_weight_map(self, param_name: str, shard_file: str):
"""
Append a weight map entry to the index file.
Args:
param_name (str): name of the parameter.
shard_file (str): name of the shard file.
"""
self.weight_map[param_name] = shard_file
def append_meta_data(self, name: str, val: Any):
"""
Append a metadata entry to the index file.
Args:
name (str): name of the metadata.
val (Any): value of the metadata.
"""
self.metadata[name] = val
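To make the index file layout concrete, here is a short sketch of building and exporting one with ShardCheckpointIndexFile (the parameter names, shard file names, and metadata entry are illustrative; the resulting JSON, produced by export() above, is shown in the trailing comment):

    from colossalai.checkpoint_io import ShardCheckpointIndexFile

    index = ShardCheckpointIndexFile()
    index.append_meta_data('total_size', 4000000)    # illustrative metadata entry
    # map each parameter name to the shard file that stores it
    index.append_weight_map('embedding.weight', 'model-0-of-2.bin')
    index.append_weight_map('fc.weight', 'model-1-of-2.bin')
    index.export('model.index.json')

    # model.index.json now contains:
    # {
    #     "metadata": {"total_size": 4000000},
    #     "weight_map": {
    #         "embedding.weight": "model-0-of-2.bin",
    #         "fc.weight": "model-1-of-2.bin"
    #     }
    # }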

colossalai/checkpoint_io/general_checkpoint_io.py

@@ -0,0 +1,66 @@
from pathlib import Path
from typing import Optional
import torch.nn as nn
from torch.optim import Optimizer
from .checkpoint_io_base import CheckpointIO
__all__ = ['GeneralCheckpointIO']
class GeneralCheckpointIO(CheckpointIO):
def load_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
checkpoint_path = Path(checkpoint)
is_sharded = self.is_sharded_checkpoint(checkpoint_path)
if not is_sharded:
state_dict = self.load_state_dict(checkpoint_path)
model.load_state_dict(state_dict, strict=strict)
else:
# find the index file
index_file_path = self.get_sharded_checkpoint_index_file(checkpoint_path)
# iterate over the shard checkpoint files and load each shard into the model.
# each shard only holds a subset of the parameters, so the per-shard calls
# must not be strict; a full strict check would compare the union of all shard keys
shard_files = self.get_checkpoint_shard_filenames(index_file_path)
for shard_file in shard_files:
shard_checkpoint = self.load_state_dict(shard_file)
model.load_state_dict(shard_checkpoint, strict=False)
return model
def save_model(self,
model: nn.Module,
checkpoint: str,
prefix: Optional[str] = None,
shard: bool = False,
size_per_shard: int = 1024):
checkpoint = Path(checkpoint)
if shard:
# TODO(FrankLeeeee): implement checkpoint saving to sharded checkpoint
raise NotImplementedError("Not implemented yet")
else:
self.save_checkpoint(model.state_dict(), checkpoint)
def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
checkpoint_path = Path(checkpoint)
is_sharded = self.is_sharded_checkpoint(checkpoint_path)
if not is_sharded:
state_dict = self.load_state_dict(checkpoint_path)
optimizer.load_state_dict(state_dict)
else:
# TODO(FrankLeeeee): implement checkpoint loading from sharded checkpoint
# This is not an urgent feature, so we can leave it for later
# let's implement this when we test large-scale models
raise NotImplementedError("Loading a sharded optimizer checkpoint is not implemented yet")
return optimizer
def save_optimizer(self, optimizer: Optimizer, checkpoint: str, shard: bool = False, size_per_shard: int = 1024):
checkpoint = Path(checkpoint)
if shard:
# TODO(FrankLeeeee): implement checkpoint saving to sharded checkpoint
raise NotImplementedError("Saving a sharded optimizer checkpoint is not implemented yet")
else:
self.save_checkpoint(optimizer.state_dict(), checkpoint)
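save_model with shard=True is left as a TODO above. As an illustration only, here is a hedged sketch of how it could be composed from the base-class helpers, using greedy packing by calculate_param_size; the helper name save_model_sharded_sketch and the 0-based shard numbering are assumptions, not the PR's implementation:

    from pathlib import Path

    import torch.nn as nn

    from colossalai.checkpoint_io import GeneralCheckpointIO, ShardCheckpointIndexFile


    def save_model_sharded_sketch(ckpt_io: GeneralCheckpointIO, model: nn.Module,
                                  checkpoint_dir: str, prefix: str = 'model',
                                  size_per_shard: int = 1024) -> None:
        # hypothetical helper, not part of this PR
        checkpoint_path = Path(checkpoint_dir)
        state_dict = model.state_dict()

        # greedily pack parameters into shards of at most size_per_shard MB
        shards, current_shard, current_size = [], {}, 0.0
        for name, param in state_dict.items():
            param_size = ckpt_io.calculate_param_size(param)
            if current_shard and current_size + param_size > size_per_shard:
                shards.append(current_shard)
                current_shard, current_size = {}, 0.0
            current_shard[name] = param
            current_size += param_size
        if current_shard:
            shards.append(current_shard)

        # write each shard and record its parameters in the index file
        index = ShardCheckpointIndexFile()
        total = len(shards)
        for i, shard in enumerate(shards):    # 0-based, matching the ShardCheckpointIndexFile docstring example
            ckpt_io.save_state_dict_as_shard(shard, i, total, prefix, checkpoint_path)
            shard_file = ckpt_io.generate_checkpoint_shard_file_name(i, total, prefix)
            for name in shard:
                index.append_weight_map(name, shard_file)
        index.export(str(checkpoint_path.joinpath(f'{prefix}.index.json')))

The matching load path already exists: GeneralCheckpointIO.load_model accepts the directory and locates the single *.index.json file via get_sharded_checkpoint_index_file.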

Test file for the checkpoint IO module

@@ -0,0 +1,70 @@
import tempfile
import torch
from torch.optim import Adam
from torchvision.models import resnet18
from colossalai.checkpoint_io import GeneralCheckpointIO
# ========
# Note:
# 1. because checkpoint IO can be quite slow when tested with all models, we only test on ResNet for now
# 2. we will test on both sharded and unsharded checkpoints
# 3. TODO(FrankLeeeee): implement sharded checkpoint and test it
# ========
def test_unsharded_checkpoint():
# create a model and optimizer
model = resnet18()
optimizer = Adam(model.parameters(), lr=0.001)
# create test data sample
x = torch.randn(1, 3, 224, 224)
# run fwd and bwd
y = model(x)
loss = y.sum()
loss.backward()
optimizer.step()
# create a temp file for checkpoint
model_ckpt_tempfile = tempfile.NamedTemporaryFile()
optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile()
# save the model and optimizer
ckpt_io = GeneralCheckpointIO()
ckpt_io.save_model(model, model_ckpt_tempfile.name)
ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name)
# create new model
new_model = resnet18()
new_optimizer = Adam(new_model.parameters(), lr=0.001)
# load the model and optimizer
new_model = ckpt_io.load_model(new_model, model_ckpt_tempfile.name)
new_optimizer = ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name)
# do recursive check for the optimizer state dict
# if the value is a dict, compare its values
# if the value is a list, compare all elements one-by-one
# if the value is a torch.Tensor, use torch.equal
# otherwise compare directly with ==
def recursive_check(d1, d2):
for k, v in d1.items():
if isinstance(v, dict):
recursive_check(v, d2[k])
elif isinstance(v, list):
for i in range(len(v)):
if isinstance(v[i], torch.Tensor):
assert torch.equal(v[i], d2[k][i])
else:
assert v[i] == d2[k][i]
elif isinstance(v, torch.Tensor):
assert torch.equal(v, d2[k])
else:
assert v == d2[k]
# check for model and optimizer state dict recursively
recursive_check(model.state_dict(), new_model.state_dict())
recursive_check(optimizer.state_dict(), new_optimizer.state_dict())
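Note 3 above leaves sharded checkpoints untested. Below is a hedged sketch of such a test, which exercises the implemented sharded load path by writing the shards and index by hand, since sharded saving itself is still a TODO; the two-way shard split and the test name are arbitrary choices for illustration:

    import tempfile
    from pathlib import Path

    import torch
    from torchvision.models import resnet18

    from colossalai.checkpoint_io import GeneralCheckpointIO, ShardCheckpointIndexFile


    def test_sharded_checkpoint_load():
        model = resnet18()
        ckpt_io = GeneralCheckpointIO()

        with tempfile.TemporaryDirectory() as tmpdir:
            tmp_path = Path(tmpdir)
            state_dict = model.state_dict()

            # split the state dict into two halves and save each half as a shard
            keys = list(state_dict.keys())
            half = len(keys) // 2
            shards = [{k: state_dict[k] for k in keys[:half]},
                      {k: state_dict[k] for k in keys[half:]}]

            # build the index file that maps every parameter to its shard file
            index = ShardCheckpointIndexFile()
            for i, shard in enumerate(shards):
                ckpt_io.save_state_dict_as_shard(shard, i, len(shards), 'model', tmp_path)
                shard_file = ckpt_io.generate_checkpoint_shard_file_name(i, len(shards), 'model')
                for name in shard:
                    index.append_weight_map(name, shard_file)
            index.export(str(tmp_path.joinpath('model.index.json')))

            # a directory containing exactly one *.index.json goes through the sharded path
            new_model = resnet18()
            new_model = ckpt_io.load_model(new_model, str(tmp_path))

            for k in state_dict:
                assert torch.equal(state_dict[k], new_model.state_dict()[k])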