from abc import ABC, abstractmethod
from pathlib import Path
from typing import Optional, Union

import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler

from colossalai.interface import ModelWrapper

from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, has_index_file

__all__ = ["CheckpointIO"]


class CheckpointIO(ABC):
    """
    CheckpointIO is the base class for all checkpoint IO classes. It defines the interface for checkpoint IO.

    This class is abstract: the public ``load_*`` / ``save_*`` methods only decide between the sharded
    and the unsharded code path and then delegate to the abstract hooks, which a concrete subclass
    must implement. The examples below therefore use a concrete implementation.

    Examples:
        >>> from colossalai.checkpoint_io import GeneralCheckpointIO
        >>> checkpoint_io = GeneralCheckpointIO()
        >>>
        >>> # load model from checkpoint
        >>> model = checkpoint_io.load_model(model, 'model.pt')
        >>>
        >>> # save model to checkpoint, any distributed tensor is gathered by default
        >>> checkpoint_io.save_model(model, 'model.pt')
        >>>
        >>> # if the model contains distributed tensor, and you don't want to gather it
        >>> # each rank will save its own shard of the distributed tensor
        >>> checkpoint_io.save_model(model, 'model.pt', gather_dtensor=False)
        >>>
        >>> # save model to sharded checkpoints
        >>> checkpoint_io.save_model(model, './checkpoints/', shard=True)
        >>>
        >>> # save model to sharded and assume we don't want to gather distributed tensors
        >>> checkpoint_io.save_model(model, './checkpoints/', shard=True, gather_dtensor=False)
        >>>
        >>> # Note:
        >>> # 1. we don't support loading from distributed tensors, conversion from distributed tensors
        >>> #    checkpoints to full tensor checkpoint should be done offline via our CLI
        >>> # 2. you don't have to specify whether the model is sharded or not when loading the model
        >>> #    as it will be automatically detected
        >>>
        >>> # load model from sharded checkpoints
        >>> model = checkpoint_io.load_model(model, './checkpoints/')
        >>>
        >>> # load model from unsharded checkpoints
        >>> model = checkpoint_io.load_model(model, './checkpoints/')
        >>>
        >>> # load optimizer from checkpoint
        >>> optimizer = checkpoint_io.load_optimizer(optimizer, 'optimizer.pt')
        >>>
        >>> # save optimizer to checkpoint
        >>> checkpoint_io.save_optimizer(optimizer, 'optimizer.pt')
    """

    # ======================================
    # Public methods
    # ======================================
    def load_model(
        self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True
    ) -> Union[nn.Module, ModelWrapper]:
        """
        Load model from checkpoint.

        The checkpoint layout is detected in this order:

        1. if ``has_index_file(checkpoint)`` reports an index file, the checkpoint is treated
           as sharded and ``load_sharded_model`` is called with the index file path;
        2. otherwise, if ``checkpoint/SAFE_WEIGHTS_NAME`` is an existing file, it is loaded
           with ``load_unsharded_model``;
        3. otherwise, if ``checkpoint/WEIGHTS_NAME`` is an existing file, it is loaded
           with ``load_unsharded_model``;
        4. otherwise ``checkpoint`` itself is handed to ``load_unsharded_model`` unchanged.

        Args:
            model (nn.Module): model to be loaded.
            checkpoint (str): checkpoint path. This value is made compatible with the model checkpoints in the
                        mainstream model zoos such as Hugging Face and TIMM. The checkpoint path can be:
                        1. a file path, e.g. 'model.pt'
                        2. a path to a json file which defines the index to the sharded checkpoint
                        3. a path to a folder containing a unique .index.json file for sharded checkpoint
                        Distributed tensors cannot be loaded directly unless gathered offline via our CLI.
            strict (bool): whether to strictly enforce that the param names in the checkpoint match the
                        keys of the model's state dict. The flag is forwarded as-is to the
                        ``load_sharded_model`` / ``load_unsharded_model`` implementation.

        Returns:
            Union[nn.Module, ModelWrapper]: the very same ``model`` object that was passed in
            (i.e. the original, possibly wrapped, model rather than an unwrapped one).
        """
        # since we only support loading sharded and unsharded weight formats
        # containing no distributed tensors, dtensor -> full tensor conversion
        # should be done offline via our CLI

        # the existence of an index file means it is a sharded checkpoint
        index_file_exists, index_file_path = has_index_file(checkpoint)
        if index_file_exists:
            self.load_sharded_model(model, index_file_path, strict)
            return model

        # `checkpoint` may be a directory holding a single weights file under one of the
        # well-known names; safetensors takes precedence over the regular weights file
        for weights_name in (SAFE_WEIGHTS_NAME, WEIGHTS_NAME):
            candidate = Path(checkpoint, weights_name)
            if candidate.is_file():
                self.load_unsharded_model(model, str(candidate), strict)
                return model

        # fall back to treating `checkpoint` itself as the weights file
        self.load_unsharded_model(model, checkpoint, strict)
        return model

    def save_model(
        self,
        model: Union[nn.Module, ModelWrapper],
        checkpoint: str,
        shard: bool = False,
        gather_dtensor: bool = True,
        prefix: Optional[str] = None,
        size_per_shard: int = 1024,
        use_safetensors: bool = False,
    ):
        """
        Save model to checkpoint.

        Examples:
            >>> from colossalai.checkpoint_io import GeneralCheckpointIO
            >>> checkpoint_io = GeneralCheckpointIO()
            >>>
            >>> # save model to a single file
            >>> checkpoint_io.save_model(model, 'model.pt')
            >>>
            >>> # save model to a sharded checkpoint
            >>> checkpoint_io.save_model(model, './checkpoints/', shard=True)

        Args:
            model (nn.Module): model to be saved.
            checkpoint (str): checkpoint path. The checkpoint path can be :
                1. a file path, e.g. 'model.pt'
                2. a directory path to save the sharded checkpoint, e.g. './checkpoints/' when shard = True.
            shard (bool): whether to shard the checkpoint. Default: False. If set to True, the checkpoint will be sharded into
                multiple files. The model shards will be specified by a `model.index.json` file. When shard = True, please ensure
                that the checkpoint path is a directory path instead of a file path.
            gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True.
            prefix (str, optional): prefix forwarded to ``save_sharded_model``; it is not used when shard = False.
                Presumably it becomes part of the shard file names (e.g. ``pytorch_model.<prefix>.bin``) --
                confirm against the concrete implementation. Default: None.
            size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard = True.
            use_safetensors (bool): whether to use safe tensors. Default: False. If set to True, the checkpoint
                will be saved in the safetensors format.
        """
        if shard:
            self.save_sharded_model(model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors)
        else:
            self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)

    def load_optimizer(
        self, optimizer: Optimizer, checkpoint: str, prefix: Optional[str] = None, size_per_shard: int = 1024
    ):
        """
        Load optimizer from checkpoint.

        Args:
            optimizer (Optimizer): optimizer to be loaded.
            checkpoint (str): checkpoint path. If ``has_index_file(checkpoint)`` reports an index file, the
                checkpoint is loaded as a sharded checkpoint; otherwise it must be a single file, which is
                loaded as an unsharded checkpoint.
            prefix (str, optional): prefix forwarded to ``load_sharded_optimizer``; it is not used for
                unsharded checkpoints. Defaults to None.
            size_per_shard (int, optional): not used by this method; accepted only so that existing callers
                passing it keep working. Defaults to 1024.

        Raises:
            ValueError: if ``checkpoint`` is a directory in which no index file can be found.
        """
        index_file_exists, index_file_path = has_index_file(checkpoint)

        # a directory can only be a sharded checkpoint, which requires an index file
        if Path(checkpoint).is_dir() and not index_file_exists:
            raise ValueError(f"Cannot find index file in {checkpoint}")

        if index_file_exists:
            # the existence of index file means it is a sharded checkpoint
            self.load_sharded_optimizer(optimizer, index_file_path, prefix)
        else:
            self.load_unsharded_optimizer(optimizer, checkpoint)

    def save_optimizer(
        self,
        optimizer: Optimizer,
        checkpoint: str,
        shard: bool = False,
        gather_dtensor: bool = True,
        prefix: Optional[str] = None,
        size_per_shard: int = 1024,
    ):
        """
        Save optimizer to checkpoint. Optimizer states saving is not compatible with safetensors.

        Args:
            optimizer (Optimizer): optimizer to be saved.
            checkpoint (str): checkpoint path. The checkpoint path can be :
                1. a file path, e.g. 'optimizer.pt'
                2. a directory path to save the sharded checkpoint, e.g. './checkpoints/' when shard = True.
            shard (bool): whether to shard the checkpoint. Default: False. If set to True, the checkpoint will be sharded into
                multiple files. The optimizer shards will be specified by a `optimizer.index.json` file.
            gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True.
            prefix (str, optional): prefix for the optimizer checkpoint when shard = True. Default: None.
            size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True.
        """
        if shard:
            self.save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
        else:
            self.save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)

    # ========================================================
    # Abstract methods for model loading/saving implementation
    # ========================================================
    @abstractmethod
    def load_sharded_model(self, model: nn.Module, index_file_path: str, strict: bool):
        """
        Load model from sharded checkpoint.

        Args:
            model (nn.Module): model to be loaded.
            index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
            strict (bool): whether to strictly enforce that the param names in the checkpoint match the keys
                of the model's state dict.
        """

    @abstractmethod
    def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
        """
        Load model from unsharded checkpoint.

        Args:
            model (nn.Module): model to be loaded.
            checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
            strict (bool): whether to strictly enforce that the param names in the checkpoint match the keys
                of the model's state dict.
        """

    @abstractmethod
    def save_sharded_model(
        self,
        model: nn.Module,
        checkpoint: str,
        gather_dtensor: bool,
        prefix: Optional[str],
        size_per_shard: int,
        use_safetensors: bool,
    ):
        """
        Save model to sharded checkpoint.

        Args:
            model (nn.Module): model to be saved.
            checkpoint (str): checkpoint path. It should be a directory path.
            gather_dtensor (bool): whether to gather the distributed tensor to the first device.
            prefix (str, optional): prefix for the model checkpoint. May be None.
            size_per_shard (int): size per shard in MB.
            use_safetensors (bool): whether to use safe tensors.
        """

    @abstractmethod
    def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
        """
        Save model to unsharded checkpoint.

        Args:
            model (nn.Module): model to be saved.
            checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
            gather_dtensor (bool): whether to gather the distributed tensor to the first device.
            use_safetensors (bool): whether to use safe tensors.
        """

    # ========================================================
    # Abstract methods for optimizer loading/saving implementation
    # ========================================================
    @abstractmethod
    def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: Optional[str]):
        """
        Load optimizer from sharded checkpoint.

        Args:
            optimizer (Optimizer): optimizer to be loaded.
            index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
            prefix (str, optional): prefix for the optimizer checkpoint. ``load_optimizer`` passes None
                when the caller did not specify one.
        """

    @abstractmethod
    def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Union[str, Path]):
        """
        Load optimizer from unsharded checkpoint.

        Args:
            optimizer (Optimizer): optimizer to be loaded.
            checkpoint (str or Path): checkpoint path. It should be a single file path pointing to an
                optimizer state file. ``load_optimizer`` forwards the caller's value unchanged.
        """

    @abstractmethod
    def save_sharded_optimizer(
        self,
        optimizer: Optimizer,
        checkpoint: Union[str, Path],
        gather_dtensor: bool,
        prefix: Optional[str],
        size_per_shard: int,
    ):
        """
        Save optimizer to sharded checkpoint.

        Args:
            optimizer (Optimizer): optimizer to be saved.
            checkpoint (str or Path): checkpoint path. It should be a directory path.
            gather_dtensor (bool): whether to gather the distributed tensor to the first device.
            prefix (str, optional): prefix for the optimizer checkpoint. ``save_optimizer`` passes None
                when the caller did not specify one.
            size_per_shard (int): size per shard in MB.
        """

    @abstractmethod
    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Union[str, Path], gather_dtensor: bool):
        """
        Save optimizer to unsharded checkpoint.

        Args:
            optimizer (Optimizer): optimizer to be saved.
            checkpoint (str or Path): checkpoint path. It should be a single file path pointing to an
                optimizer state file.
            gather_dtensor (bool): whether to gather the distributed tensor to the first device.
        """

    # ============================================
    # methods for loading and saving lr scheduler
    # as this is quite standard, there is no need
    # to make them abstract
    # ============================================
    def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
        """
        Save lr scheduler to checkpoint.

        The scheduler's ``state_dict()`` is written with ``torch.save``.

        Args:
            lr_scheduler (LRScheduler): lr scheduler to be saved.
            checkpoint (str): checkpoint path. The checkpoint path can only be a file path.
        """
        torch.save(lr_scheduler.state_dict(), checkpoint)

    def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
        """
        Load lr scheduler from checkpoint.

        The file is read with ``torch.load`` and the result is passed to the scheduler's
        ``load_state_dict``; the scheduler is updated in place and nothing is returned.

        Args:
            lr_scheduler (LRScheduler): lr scheduler to be loaded.
            checkpoint (str): the path for a single checkpoint file.
        """
        # NOTE(security): torch.load relies on pickle-based deserialization, so only
        # checkpoints coming from a trusted source should be loaded here.
        state_dict = torch.load(checkpoint)
        lr_scheduler.load_state_dict(state_dict)