diff --git a/colossalai/checkpoint_io/__init__.py b/colossalai/checkpoint_io/__init__.py
new file mode 100644
index 000000000..3cec630b2
--- /dev/null
+++ b/colossalai/checkpoint_io/__init__.py
@@ -0,0 +1,4 @@
+from .checkpoint_io_base import CheckpointIO, ShardCheckpointIndexFile
+from .general_checkpoint_io import GeneralCheckpointIO
+
+__all__ = ['CheckpointIO', 'ShardCheckpointIndexFile', 'GeneralCheckpointIO']
diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py
new file mode 100644
index 000000000..00a65424b
--- /dev/null
+++ b/colossalai/checkpoint_io/checkpoint_io_base.py
@@ -0,0 +1,374 @@
+import json
+from abc import ABC, abstractmethod
+from pathlib import Path
+from typing import Any
+
+import torch
+import torch.nn as nn
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+
+__all__ = ['CheckpointIO', 'ShardCheckpointIndexFile']
+
+
+class CheckpointIO(ABC):
+    """
+    CheckpointIO is the base class for all checkpoint IO classes. It defines the interface for checkpoint IO.
+
+    Examples:
+        >>> from colossalai.checkpoint_io import GeneralCheckpointIO
+        >>> checkpoint_io = GeneralCheckpointIO()
+        >>>
+        >>> # load model from checkpoint
+        >>> model = checkpoint_io.load_model(model, 'model.pt')
+        >>>
+        >>> # save model to checkpoint
+        >>> checkpoint_io.save_model(model, 'model.pt')
+        >>>
+        >>> # save model to sharded checkpoints
+        >>> checkpoint_io.save_model(model, './checkpoints/', shard=True)
+        >>>
+        >>> # load model from sharded checkpoints
+        >>> model = checkpoint_io.load_model(model, './checkpoints/')
+        >>>
+        >>> # load optimizer from checkpoint
+        >>> optimizer = checkpoint_io.load_optimizer(optimizer, 'optimizer.pt')
+        >>>
+        >>> # save optimizer to checkpoint
+        >>> checkpoint_io.save_optimizer(optimizer, 'optimizer.pt')
+    """
+
+    # ======================================
+    # Abstract methods for implementation
+    # ======================================
+
+    @abstractmethod
+    def load_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
+        """
+        Load model from checkpoint.
+
+        Args:
+            model (nn.Module): model to be loaded.
+            checkpoint (str): checkpoint path. This path is kept compatible with model checkpoints from
+                mainstream model zoos such as Hugging Face and TIMM. The checkpoint path can be:
+                1. a file path, e.g. 'model.pt'
+                2. a path to a json file which defines the index to the sharded checkpoint
+                3. a path to a folder containing a unique .index.json file for sharded checkpoint
+            strict (bool): whether to strictly enforce that the parameter names in the checkpoint
+                match the keys returned by this module's state_dict().
+        """
+        pass
+
+    @abstractmethod
+    def save_model(self,
+                   model: nn.Module,
+                   checkpoint: str,
+                   prefix: str = None,
+                   shard: bool = False,
+                   size_per_shard: int = 1024):
+        """
+        Save model to checkpoint.
+
+        Examples:
+            >>> from colossalai.checkpoint_io import GeneralCheckpointIO
+            >>> checkpoint_io = GeneralCheckpointIO()
+            >>>
+            >>> # save model to a single file
+            >>> checkpoint_io.save_model(model, 'model.pt')
+            >>>
+            >>> # save model to a sharded checkpoint
+            >>> checkpoint_io.save_model(model, './checkpoints/', shard=True)
+
+        Args:
+            model (nn.Module): model to be saved.
+            checkpoint (str): checkpoint path. The checkpoint path can be:
+                1. a file path, e.g. 'model.pt'
+                2. a directory path to save the sharded checkpoint, e.g. './checkpoints/' when shard = True.
+            prefix (str): prefix for the shard checkpoint file names when shard = True. Default: None.
+            shard (bool): whether to shard the checkpoint. Default: False.
+                If set to True, the checkpoint will be sharded into multiple files. The model shards
+                will be specified by a `model.index.json` file. When shard = True, please ensure that
+                the checkpoint path is a directory path instead of a file path.
+            size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True.
+        """
+        pass
+
+    @abstractmethod
+    def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
+        """
+        Load optimizer from checkpoint.
+
+        Args:
+            optimizer (Optimizer): optimizer to be loaded.
+            checkpoint (str): checkpoint path. This path is kept compatible with checkpoints from
+                mainstream model zoos such as Hugging Face and TIMM. The checkpoint path can be:
+                1. a file path, e.g. 'optimizer.pt'
+                2. a path to a json file which defines the index to the sharded checkpoint
+                3. a path to a folder containing a unique .index.json file for sharded checkpoint
+        """
+        pass
+
+    @abstractmethod
+    def save_optimizer(self, optimizer: Optimizer, checkpoint: str, shard: bool = False, size_per_shard: int = 1024):
+        """
+        Save optimizer to checkpoint.
+
+        Args:
+            optimizer (Optimizer): optimizer to be saved.
+            checkpoint (str): checkpoint path. The checkpoint path can be:
+                1. a file path, e.g. 'optimizer.pt'
+                2. a path to a json file which defines the index to the sharded checkpoint for the optimizer
+                3. a path to a folder containing a unique .index.json file for sharded checkpoint
+            shard (bool): whether to shard the checkpoint. Default: False.
+            size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True.
+        """
+        pass
+
+    # ============================================
+    # methods for loading and saving lr scheduler
+    # as this is quite standard, there is no need
+    # to make them abstract
+    # ============================================
+
+    def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
+        """
+        Save lr scheduler to checkpoint.
+
+        Args:
+            lr_scheduler (LRScheduler): lr scheduler to be saved.
+            checkpoint (str): checkpoint path. The checkpoint path can only be a file path.
+        """
+        torch.save(lr_scheduler.state_dict(), checkpoint)
+
+    def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
+        """
+        Load lr scheduler from checkpoint.
+
+        Args:
+            lr_scheduler (LRScheduler): lr scheduler to be loaded.
+            checkpoint (str): the path for a single checkpoint file.
+        """
+        state_dict = torch.load(checkpoint)
+        lr_scheduler.load_state_dict(state_dict)
+
+    # ========================================
+    # Helper functions for loading state dict
+    # ========================================
+
+    def get_sharded_checkpoint_index_file(self, checkpoint_path: Path):
+        """
+        Get the index file path for a sharded checkpoint.
+
+        Args:
+            checkpoint_path (Path): path to the checkpoint.
+
+        Returns:
+            Path: path to the index file.
+        """
+        if checkpoint_path.is_file():
+            # a file path is only valid if it is an index file
+            if checkpoint_path.name.endswith('.index.json'):
+                return checkpoint_path
+            else:
+                raise ValueError(f'Invalid checkpoint path: {checkpoint_path}. Expected a .index.json file.')
+        elif checkpoint_path.is_dir():
+            # check if there is exactly one file ending with .index.json in this directory
+            index_files = list(checkpoint_path.glob('*.index.json'))
+            if len(index_files) == 1:
+                return index_files[0]
+            else:
+                raise ValueError(f'Found {len(index_files)} index files in {checkpoint_path}, expected exactly 1.')
+        else:
+            raise ValueError(f'Checkpoint path {checkpoint_path} does not exist.')
+
+    def is_sharded_checkpoint(self, checkpoint_path: Path):
+        """
+        Check whether the checkpoint is sharded.
+
+        Args:
+            checkpoint_path (Path): path to the checkpoint.
+
+        Returns:
+            bool: whether the checkpoint is sharded.
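+
+        Example (illustrative; `ckpt_io` stands for any concrete CheckpointIO such as
+        GeneralCheckpointIO, and './checkpoints/' is assumed to contain exactly one .index.json file):
+            >>> ckpt_io.is_sharded_checkpoint(Path('./checkpoints/'))
+            True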
+ """ + if checkpoint_path.is_file(): + # check if it is .index.json + if checkpoint_path.name.endswith('.index.json'): + return True + else: + return False + elif checkpoint_path.is_dir(): + # check if there is only one a file ending with .index.json in this directory + index_files = list(checkpoint_path.glob('*.index.json')) + if len(index_files) == 1: + return True + else: + raise ValueError(f'Found {len(index_files)} index files in {checkpoint_path}. ') + + def get_checkpoint_shard_filenames(self, index_file_path: Path): + """ + Get checkpoint shard filenames from a json file. + + Args: + index_file_path (Path): path to the json file. + + Returns: + list: checkpoint shard filenames. + """ + with open(str(index_file_path), 'r') as f: + shard_filenames = json.load(f) + + if "weight_map" in index: + index = index["weight_map"] + + checkpoint_root_path = index_file_path.absolute().parent + + # read the checkpoint file list from the json file and get a list of unique file names + checkpoint_files = sorted(list(set(index.values()))) + + # get the absolute paths for all checkpoint files + checkpoint_files = [checkpoint_root_path.joinpath(f) for f in checkpoint_files] + return shard_filenames + + def load_safetensors_state_dict(self, *args, **kwargs): + """ + Load safetensors state dict from checkpoint. + """ + # TODO(FrankLeeeee): support huggingface safetensors + raise NotImplementedError("This method is not implemented to support safe tensors") + + def load_state_dict(self, checkpoint_file_path: Path): + """ + Load state dict from checkpoint. + + Args: + checkpoint_file_path (Path): path to the checkpoint file. + + Returns: + dict: state dict. + """ + return torch.load(str(checkpoint_file_path)) + + # ====================================== + # Helper functions for saving state dict + # ====================================== + + def save_safetensors_state_dict(self, *args, **kwargs): + """ + Save safetensors state dict to checkpoint. + """ + # TODO(FrankLeeeee): support huggingface safetensors + raise NotImplementedError("This method is not implemented to support safe tensors") + + def generate_checkpoint_shard_file_name(self, index: int, total_number: int, prefix: str = None): + """ + Generate checkpoint shard file name. + + Args: + index (int): index of the shard. + total_number (int): total number of shards. + prefix (str): prefix of the shard file name. Default: None. + """ + if prefix is None: + return f"{index}-of-{total_number}.bin" + else: + return f"{prefix}-{index}-of-{total_number}.bin" + + def save_checkpoint(self, state_dict: dict, checkpoint_file_path: Path): + """ + Save state dict to checkpoint. + + Args: + state_dict (dict): state dict. + checkpoint_file_path (Path): path to the checkpoint file. + """ + torch.save(state_dict, str(checkpoint_file_path)) + + def save_state_dict_as_shard(self, state_dict: dict, index: int, total_number: int, prefix: str, + checkpoint_path: Path): + """ + Save state dict as shard. + + Args: + state_dict (dict): state dict. + checkpoint_path (Path): path to the checkpoint file. + """ + # generate the shard name + shard_file_name = self.generate_checkpoint_shard_file_name(index, total_number, prefix) + shard_file_path = checkpoint_path.joinpath(shard_file_name) + + # save the shard + self.save_checkpoint(state_dict, shard_file_path) + + def calculate_param_size(self, param: torch.Tensor): + """ + Calculate the size of a parameter in MB. Used to compute whether a group of params exceed the shard size. + If so, a new shard should be created. 
+
+        Args:
+            param (torch.Tensor): parameter tensor.
+
+        Returns:
+            float: size of the parameter in MB.
+        """
+        # TODO(FrankLeeeee): check if this tensor is a DTensor, compute its global size if so
+        return param.numel() * param.element_size() / 1024 / 1024
+
+
+class ShardCheckpointIndexFile:
+    """
+    This class is a data structure to keep the content of the index.json file for a sharded checkpoint.
+
+    Example:
+        >>> index = ShardCheckpointIndexFile()
+        >>> index.load('index.json')
+        >>> index.append_meta_data('model_type', 'bert')
+        >>> index.append_weight_map('bert.embeddings.word_embeddings.weight', 'bert.embeddings.word_embeddings.weight-0-of-2.bin')
+        >>> index.export('index.json')
+    """
+
+    def __init__(self) -> None:
+        self.metadata: dict = dict()
+        self.weight_map: dict = dict()
+
+    def load(self, json_path: str):
+        """
+        Load the index file from a json file.
+
+        Args:
+            json_path (str): path to the json file.
+        """
+        # load the json file
+        with open(json_path, 'r') as f:
+            index = json.load(f)
+
+        # assign attributes if they exist
+        if "metadata" in index:
+            self.metadata = index["metadata"]
+        if "weight_map" in index:
+            self.weight_map = index["weight_map"]
+
+    def export(self, json_path: str):
+        """
+        Export the index file to a json file.
+
+        Args:
+            json_path (str): path to the json file.
+        """
+        # build the index dict
+        index = dict()
+        index["metadata"] = self.metadata
+        index["weight_map"] = self.weight_map
+
+        # export the index file
+        with open(json_path, 'w') as f:
+            json.dump(index, f, indent=4)
+
+    def append_weight_map(self, param_name: str, shard_file: str):
+        """
+        Append a weight map entry to the index file.
+
+        Args:
+            param_name (str): name of the parameter.
+            shard_file (str): name of the shard file.
+        """
+        self.weight_map[param_name] = shard_file
+
+    def append_meta_data(self, name: str, val: Any):
+        """
+        Append a metadata entry to the index file.
+
+        Args:
+            name (str): name of the metadata.
+            val (Any): value of the metadata.
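+
+        Example (with an illustrative metadata key):
+            >>> index = ShardCheckpointIndexFile()
+            >>> index.append_meta_data('total_size', 1024)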
+ """ + self.metadata[name] = val diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py new file mode 100644 index 000000000..0a3636655 --- /dev/null +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -0,0 +1,66 @@ +from pathlib import Path + +import torch.nn as nn +from torch.optim import Optimizer + +from .checkpoint_io_base import CheckpointIO + +__all__ = ['GeneralCheckpointIO'] + + +class GeneralCheckpointIO(CheckpointIO): + + def load_model(self, model: nn.Module, checkpoint: str, strict: bool = True): + checkpoint = Path(checkpoint) + is_sharded = self.is_sharded_checkpoint(checkpoint) + + if not is_sharded: + checkpoint = self.load_state_dict(checkpoint) + model.load_state_dict(checkpoint, strict=strict) + else: + # find the index file + checkpoint_path = Path(checkpoint) + index_file_path = self.get_sharded_checkpoint_index_file(checkpoint_path) + + # iterate over the shard checkpoint files + # and load each + shard_files = self.get_checkpoint_shard_filenames(index_file_path) + for shard_file in shard_files: + shard_checkpoint = self.load_state_dict(shard_file) + model.load_state_dict(shard_checkpoint, strict=strict) + + return model + + def save_model(self, + model: nn.Module, + checkpoint: str, + prefix: str = None, + shard: bool = False, + size_per_shard: int = 1024): + checkpoint = Path(checkpoint) + if shard: + # TODO(FrankLeeeee): implement checkpoint saving to sharded checkpoint + raise NotImplementedError("Not implemented yet") + else: + self.save_checkpoint(model.state_dict(), checkpoint) + + def load_optimizer(self, optimizer: Optimizer, checkpoint: str): + checkpoint = Path(checkpoint) + is_sharded = self.is_sharded_checkpoint(checkpoint) + + if not is_sharded: + checkpoint = self.load_state_dict(checkpoint) + optimizer.load_state_dict(checkpoint) + else: + # TODO(FrankLeeeee): implement checkpoint loading from sharded checkpoint + # This is not an urgent feature, so we can leave it for later + # let's implement this when we test large-scale models + pass + return optimizer + + def save_optimizer(self, optimizer: Optimizer, checkpoint: str, shard: bool = False, size_per_shard: int = 1024): + if shard: + # TODO(FrankLeeeee): implement checkpoint saving to sharded checkpoint + pass + else: + self.save_checkpoint(optimizer.state_dict(), checkpoint) diff --git a/tests/test_checkpoint_io/test_general_checkpoint_io.py b/tests/test_checkpoint_io/test_general_checkpoint_io.py new file mode 100644 index 000000000..48376aaa8 --- /dev/null +++ b/tests/test_checkpoint_io/test_general_checkpoint_io.py @@ -0,0 +1,70 @@ +import tempfile + +import torch +from torch.optim import Adam +from torchvision.models import resnet18 + +from colossalai.checkpoint_io import GeneralCheckpointIO + +# ======== +# Note: +# 1. due to checkpoint IO can be quite slow if tested with all models, we will only test on resnet for now +# 2. we will test on both sharded and unsharded checkpoints +# 3. 
+# 3. TODO(FrankLeeeee): implement sharded checkpoint and test it
+# ========
+
+
+def test_unsharded_checkpoint():
+    # create a model and optimizer
+    model = resnet18()
+    optimizer = Adam(model.parameters(), lr=0.001)
+
+    # create test data sample
+    x = torch.randn(1, 3, 224, 224)
+
+    # run fwd and bwd
+    y = model(x)
+    loss = y.sum()
+    loss.backward()
+    optimizer.step()
+
+    # create temp files for the checkpoints
+    model_ckpt_tempfile = tempfile.NamedTemporaryFile()
+    optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile()
+
+    # save the model and optimizer
+    ckpt_io = GeneralCheckpointIO()
+    ckpt_io.save_model(model, model_ckpt_tempfile.name)
+    ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name)
+
+    # create a new model and optimizer
+    new_model = resnet18()
+    new_optimizer = Adam(new_model.parameters(), lr=0.001)
+
+    # load the model and optimizer
+    new_model = ckpt_io.load_model(new_model, model_ckpt_tempfile.name)
+    new_optimizer = ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name)
+
+    # do a recursive check for the state dicts:
+    # if the value is a dict, compare its values
+    # if the value is a list, compare all elements one-by-one
+    # if the value is a torch.Tensor, use torch.equal
+    # otherwise use a plain equality assertion
+    def recursive_check(d1, d2):
+        for k, v in d1.items():
+            if isinstance(v, dict):
+                recursive_check(v, d2[k])
+            elif isinstance(v, list):
+                for i in range(len(v)):
+                    if isinstance(v[i], torch.Tensor):
+                        assert torch.equal(v[i], d2[k][i])
+                    else:
+                        assert v[i] == d2[k][i]
+            elif isinstance(v, torch.Tensor):
+                assert torch.equal(v, d2[k])
+            else:
+                assert v == d2[k]
+
+    # check the model and optimizer state dicts recursively
+    recursive_check(model.state_dict(), new_model.state_dict())
+    recursive_check(optimizer.state_dict(), new_optimizer.state_dict())
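+
+
+# ========
+# Hedged sketch (not a running test): one way the sharded save marked TODO above
+# could compose the helpers that already exist in this PR. The fixed two-way split
+# below is illustrative only; a real implementation would group parameters by
+# `size_per_shard`, e.g. via `CheckpointIO.calculate_param_size`.
+# ========
+def sketch_sharded_save(model, checkpoint_dir: str, prefix: str = 'model'):
+    from pathlib import Path
+
+    from colossalai.checkpoint_io import ShardCheckpointIndexFile
+
+    ckpt_io = GeneralCheckpointIO()
+    checkpoint_path = Path(checkpoint_dir)
+    index_file = ShardCheckpointIndexFile()
+
+    # naive split of the state dict into two shards of roughly equal key counts
+    state_dict = model.state_dict()
+    keys = list(state_dict.keys())
+    half = len(keys) // 2
+    shards = [{k: state_dict[k] for k in keys[:half]},
+              {k: state_dict[k] for k in keys[half:]}]
+
+    for i, shard in enumerate(shards):
+        # save the shard and record each of its parameters in the index
+        ckpt_io.save_state_dict_as_shard(shard, i, len(shards), prefix, checkpoint_path)
+        shard_file = ckpt_io.generate_checkpoint_shard_file_name(i, len(shards), prefix)
+        for param_name in shard:
+            index_file.append_weight_map(param_name, shard_file)
+
+    # export the index so that `load_model` can discover the shards
+    index_file.export(str(checkpoint_path.joinpath(f'{prefix}.index.json')))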