2023-03-23 02:53:17 +00:00
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
from pathlib import Path
|
2023-05-18 12:05:59 +00:00
|
|
|
from typing import Optional, Union
|
2023-03-23 02:53:17 +00:00
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
from torch.optim import Optimizer
|
|
|
|
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
|
|
|
|
2023-03-27 02:24:14 +00:00
|
|
|
from colossalai.interface import ModelWrapper
|
|
|
|
|
2023-04-04 07:23:01 +00:00
|
|
|
from .utils import has_index_file
|
|
|
|
|
|
|
|
# Names exported by a star-import of this module: only the abstract base class.
__all__ = ['CheckpointIO']
|
2023-03-23 02:53:17 +00:00
|
|
|
|
|
|
|
|
|
|
|
class CheckpointIO(ABC):
    """
    CheckpointIO is the base class for all checkpoint IO classes. It defines the interface for checkpoint IO.

    The public ``load_*``/``save_*`` methods decide whether a checkpoint is sharded and then delegate to the
    abstract ``*_sharded_*``/``*_unsharded_*`` hooks, which concrete subclasses must implement.

    Examples:
        >>> from colossalai.checkpoint_io import GeneralCheckpointIO
        >>> # CheckpointIO itself is abstract, so instantiate a concrete implementation
        >>> checkpoint_io = GeneralCheckpointIO()
        >>>
        >>> # load model from checkpoint
        >>> model = checkpoint_io.load_model(model, 'model.pt')
        >>>
        >>> # save model to checkpoint, any distributed tensor is gathered by default
        >>> checkpoint_io.save_model(model, 'model.pt')
        >>>
        >>> # if the model contains distributed tensor, and you don't want to gather it
        >>> # each rank will save its own shard of the distributed tensor
        >>> checkpoint_io.save_model(model, 'model.pt', gather_dtensor=False)
        >>>
        >>> # save model to sharded checkpoints
        >>> checkpoint_io.save_model(model, './checkpoints/', shard=True)
        >>>
        >>> # save model to sharded and assume we don't want to gather distributed tensors
        >>> checkpoint_io.save_model(model, './checkpoints/', shard=True, gather_dtensor=False)
        >>>
        >>> # Note:
        >>> # 1. we don't support loading from distributed tensors, conversion from distributed tensors
        >>> # checkpoints to full tensor checkpoint should be done offline via our CLI
        >>> # 2. you don't have to specify whether the model is sharded or not when loading the model
        >>> # as it will be automatically detected
        >>>
        >>> # load model from sharded checkpoints
        >>> model = checkpoint_io.load_model(model, './checkpoints/')
        >>>
        >>> # load model from unsharded checkpoints
        >>> model = checkpoint_io.load_model(model, 'model.pt')
        >>>
        >>> # load optimizer from checkpoint
        >>> checkpoint_io.load_optimizer(optimizer, 'optimizer.pt')
        >>>
        >>> # save optimizer to checkpoint
        >>> checkpoint_io.save_optimizer(optimizer, 'optimizer.pt')
    """

    # ======================================
    # Public methods
    # ======================================
    def load_model(self,
                   model: Union[nn.Module, ModelWrapper],
                   checkpoint: str,
                   strict: bool = True) -> Union[nn.Module, ModelWrapper]:
        """
        Load model from checkpoint.

        Args:
            model (Union[nn.Module, ModelWrapper]): model to be loaded. A ModelWrapper is unwrapped before the
                weights are loaded into it.
            checkpoint (str): checkpoint path. This value is kept compatible with the model checkpoints in the
                mainstream model zoos such as Hugging Face and TIMM. The checkpoint path can be:
                1. a file path, e.g. 'model.pt'
                2. a path to a json file which defines the index to the sharded checkpoint
                3. a path to a folder containing a unique .index.json file for sharded checkpoint
                Distributed tensors cannot be loaded directly unless gathered offline via our CLI.
            strict (bool): whether to strictly enforce that the param names in the checkpoint match the keys
                of the model's state dict. Default: True.

        Returns:
            Union[nn.Module, ModelWrapper]: the object originally passed as ``model`` (i.e. still wrapped if a
                ModelWrapper was given).
        """
        # since we only support loaded sharded and unsharded weight format
        # containing no distributed tensors, dtensor -> full tensor conversion
        # should be done offline via our CLI
        # the existence of index file means it is a sharded checkpoint
        index_file_exists, index_file_path = has_index_file(checkpoint)

        # return the origin model instead of the unwrapped model
        origin_model = model

        if isinstance(model, ModelWrapper):
            model = model.unwrap()

        if index_file_exists:
            self.load_sharded_model(model, index_file_path, strict)
        else:
            self.load_unsharded_model(model, checkpoint, strict)

        return origin_model

    def save_model(self,
                   model: Union[nn.Module, ModelWrapper],
                   checkpoint: str,
                   shard: bool = False,
                   gather_dtensor: bool = True,
                   prefix: Optional[str] = None,
                   size_per_shard: int = 1024,
                   use_safetensors: bool = False):
        """
        Save model to checkpoint.

        Examples:
            >>> from colossalai.checkpoint_io import GeneralCheckpointIO
            >>> checkpoint_io = GeneralCheckpointIO()
            >>>
            >>> # save model to a single file
            >>> checkpoint_io.save_model(model, 'model.pt')
            >>>
            >>> # save model to a sharded checkpoint
            >>> checkpoint_io.save_model(model, './checkpoints/', shard=True)

        Args:
            model (Union[nn.Module, ModelWrapper]): model to be saved. A ModelWrapper is unwrapped before saving.
            checkpoint (str): checkpoint path. The checkpoint path can be:
                1. a file path, e.g. 'model.pt'
                2. a directory path to save the sharded checkpoint, e.g. './checkpoints/' when shard = True.
            shard (bool): whether to shard the checkpoint. Default: False. If set to True, the checkpoint will be
                sharded into multiple files. The model shards will be specified by a `model.index.json` file.
                When shard = True, please ensure that the checkpoint path is a directory path instead of a
                file path.
            gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True.
            prefix (Optional[str]): If specified, weights are saved in the format pytorch_model.<prefix>.bin.
                Default: None. This value is only forwarded when shard = True.
            size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard = True.
            use_safetensors (bool): whether to use safe tensors. Default: False. If set to True, the checkpoint
                will be saved in the safetensors format.
        """
        if isinstance(model, ModelWrapper):
            model = model.unwrap()

        if shard:
            self.save_sharded_model(model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors)
        else:
            self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)

    def load_optimizer(self,
                       optimizer: Optimizer,
                       checkpoint: str,
                       prefix: Optional[str] = None,
                       size_per_shard: int = 1024):
        """
        Load optimizer from checkpoint.

        Args:
            optimizer (Optimizer): optimizer to be loaded.
            checkpoint (str): checkpoint path. The checkpoint path can be:
                1. a file path, e.g. 'optimizer.pt'
                2. a path to a json file which defines the index to the sharded checkpoint
                3. a path to a folder containing a unique .index.json file for sharded checkpoint
            prefix (Optional[str]): A prefix added to parameter and buffer
                names to compose the keys in state_dict. Only forwarded for sharded checkpoints.
                Defaults to None.
            size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. Not read by this
                method; retained so existing callers that pass it keep working. Defaults to 1024.

        Raises:
            ValueError: if ``checkpoint`` is a directory that contains no index file.
        """
        index_file_exists, index_file_path = has_index_file(checkpoint)

        if Path(checkpoint).is_dir() and not index_file_exists:
            # if the checkpoint is a directory and there is no index file, raise error
            raise ValueError(f'Cannot find index file in {checkpoint}')

        if index_file_exists:
            # the existence of index file means it is a sharded checkpoint
            self.load_sharded_optimizer(optimizer, index_file_path, prefix)
        else:
            self.load_unsharded_optimizer(optimizer, checkpoint)

    def save_optimizer(self,
                       optimizer: Optimizer,
                       checkpoint: str,
                       shard: bool = False,
                       gather_dtensor: bool = True,
                       prefix: Optional[str] = None,
                       size_per_shard: int = 1024):
        """
        Save optimizer to checkpoint. Optimizer states saving is not compatible with safetensors.

        Args:
            optimizer (Optimizer): optimizer to be saved.
            checkpoint (str): checkpoint path. The checkpoint path can be:
                1. a file path, e.g. 'optimizer.pt'
                2. a directory path to save the sharded checkpoint, e.g. './checkpoints/' when shard = True.
            shard (bool): whether to shard the checkpoint. Default: False. If set to True, the checkpoint will be
                sharded into multiple files. The optimizer shards will be specified by a `optimizer.index.json`
                file.
            gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True.
            prefix (Optional[str]): prefix for the optimizer checkpoint when shard = True. Default: None.
            size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is
                set to True.
        """
        if shard:
            self.save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
        else:
            self.save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)

    # ========================================================
    # Abstract methods for model loading/saving implementation
    # ========================================================
    @abstractmethod
    def load_sharded_model(self, model: nn.Module, index_file_path: str, strict: bool):
        """
        Load model from sharded checkpoint.

        Args:
            model (nn.Module): model to be loaded.
            index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a
                directory which contains a .index.json file.
            strict (bool): whether to strictly enforce that the param names in the checkpoint match the keys
                of the model's state dict.
        """
        pass

    @abstractmethod
    def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
        """
        Load model from unsharded checkpoint.

        Args:
            model (nn.Module): model to be loaded.
            checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
            strict (bool): whether to strictly enforce that the param names in the checkpoint match the keys
                of the model's state dict.
        """
        pass

    @abstractmethod
    def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: Optional[str],
                           size_per_shard: int, use_safetensors: bool):
        """
        Save model to sharded checkpoint.

        Args:
            model (nn.Module): model to be saved.
            checkpoint (str): checkpoint path. It should be a directory path.
            gather_dtensor (bool): whether to gather the distributed tensor to the first device.
            prefix (Optional[str]): prefix for the model checkpoint. May be None.
            size_per_shard (int): size per shard in MB.
            use_safetensors (bool): whether to use safe tensors.
        """
        pass

    @abstractmethod
    def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
        """
        Save model to unsharded checkpoint.

        Args:
            model (nn.Module): model to be saved.
            checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
            gather_dtensor (bool): whether to gather the distributed tensor to the first device.
            use_safetensors (bool): whether to use safe tensors.
        """
        pass

    # ========================================================
    # Abstract methods for optimizer loading/saving implementation
    # ========================================================

    @abstractmethod
    def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: Optional[str]):
        """
        Load optimizer from sharded checkpoint.

        Args:
            optimizer (Optimizer): optimizer to be loaded.
            index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a
                directory which contains a .index.json file.
            prefix (Optional[str]): prefix for the optimizer checkpoint. May be None.
        """
        pass

    @abstractmethod
    def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str):
        """
        Load optimizer from unsharded checkpoint.

        Args:
            optimizer (Optimizer): optimizer to be loaded.
            checkpoint (str): checkpoint path. It should be a single file path pointing to an optimizer
                state binary.
        """
        pass

    @abstractmethod
    def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool,
                               prefix: Optional[str], size_per_shard: int):
        """
        Save optimizer to sharded checkpoint.

        Args:
            optimizer (Optimizer): optimizer to be saved.
            checkpoint (str): checkpoint path. It should be a directory path.
            gather_dtensor (bool): whether to gather the distributed tensor to the first device.
            prefix (Optional[str]): prefix for the optimizer checkpoint. May be None.
            size_per_shard (int): size per shard in MB.
        """
        pass

    @abstractmethod
    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
        """
        Save optimizer to unsharded checkpoint.

        Args:
            optimizer (Optimizer): optimizer to be saved.
            checkpoint (str): checkpoint path. It should be a single file path pointing to an optimizer
                state binary.
            gather_dtensor (bool): whether to gather the distributed tensor to the first device.
        """
        pass

    # ============================================
    # methods for loading and saving lr scheduler
    # as this is quite standard, there is no need
    # to make them abstract
    # ============================================
    def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
        """
        Save lr scheduler to checkpoint.

        Args:
            lr_scheduler (LRScheduler): lr scheduler to be saved.
            checkpoint (str): checkpoint path. The checkpoint path can only be a file path.
        """
        torch.save(lr_scheduler.state_dict(), checkpoint)

    def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
        """
        Load lr scheduler from checkpoint.

        Args:
            lr_scheduler (LRScheduler): lr scheduler to be loaded.
            checkpoint (str): the path for a single checkpoint file.
        """
        # NOTE(security): torch.load deserializes with pickle; only load checkpoint files from trusted sources.
        state_dict = torch.load(checkpoint)
        lr_scheduler.load_state_dict(state_dict)
|