[checkpoint] refactored the API and added safetensors support (#3427)

* [checkpoint] refactored the API and added safetensors support

* polish code
pull/3442/head
Frank Lee 2023-04-04 15:23:01 +08:00 committed by GitHub
parent 26b7aac0be
commit 1beb85cc25
9 changed files with 579 additions and 280 deletions


@@ -33,7 +33,7 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
        # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
        return super().load_unsharded_model(model, checkpoint, strict=strict)

-    def save_unsharded_model(self, model: nn.Module, checkpoint: str):
+    def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool):
        """
        Save model to checkpoint but only on master process.
        """
@@ -41,7 +41,7 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
        if self.coordinator.is_master():
            super().save_unsharded_model(model, checkpoint)

-    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str):
+    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
        """
        Save optimizer to checkpoint but only on master process.
        """


@@ -1,4 +1,5 @@
-from .checkpoint_io_base import CheckpointIO, ShardCheckpointIndexFile
+from .checkpoint_io_base import CheckpointIO
 from .general_checkpoint_io import GeneralCheckpointIO
+from .index_file import CheckpointIndexFile

-__all__ = ['CheckpointIO', 'ShardCheckpointIndexFile', 'GeneralCheckpointIO']
+__all__ = ['CheckpointIO', 'CheckpointIndexFile', 'GeneralCheckpointIO']
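The index-file helper is now part of the public package surface; a quick sanity check (assuming the package path `colossalai.checkpoint_io`, consistent with the test imports further below):

>>> from colossalai.checkpoint_io import CheckpointIO, CheckpointIndexFile, GeneralCheckpointIO
>>> issubclass(GeneralCheckpointIO, CheckpointIO)
True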


@@ -1,7 +1,6 @@
-import json
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Any, Union
+from typing import Union

 import torch
 import torch.nn as nn
@@ -10,7 +9,9 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler

 from colossalai.interface import ModelWrapper

-__all__ = ['CheckpointIO', 'ShardCheckpointIndexFile']
+from .utils import has_index_file
+
+__all__ = ['CheckpointIO']


 class CheckpointIO(ABC):
@@ -25,15 +26,31 @@ class CheckpointIO(ABC):
    >>> # load model from checkpoint
    >>> model = checkpoint_io.load_model(model, 'model.pt')
    >>>
-   >>> # save model to checkpoint
+   >>> # save model to checkpoint, any distributed tensor is gathered by default
    >>> checkpoint_io.save_model(model, 'model.pt')
    >>>
+   >>> # if the model contains distributed tensors and you don't want to gather them,
+   >>> # each rank will save its own shard of the distributed tensor
+   >>> checkpoint_io.save_model(model, 'model.pt', gather_dtensor=False)
+   >>>
    >>> # save model to sharded checkpoints
    >>> checkpoint_io.save_model(model, './checkpoints/', shard=True)
    >>>
+   >>> # save model to sharded checkpoints without gathering distributed tensors
+   >>> checkpoint_io.save_model(model, './checkpoints/', shard=True, gather_dtensor=False)
+   >>>
+   >>> # Note:
+   >>> # 1. loading from distributed-tensor checkpoints is not supported; converting them
+   >>> #    to full-tensor checkpoints should be done offline via our CLI
+   >>> # 2. you don't have to specify whether the checkpoint is sharded when loading,
+   >>> #    as this is detected automatically
+   >>>
    >>> # load model from sharded checkpoints
    >>> model = checkpoint_io.load_model(model, './checkpoints/')
    >>>
+   >>> # load model from an unsharded checkpoint
+   >>> model = checkpoint_io.load_model(model, 'model.pt')
+   >>>
    >>> # load optimizer from checkpoint
    >>> optimizer = checkpoint_io.load_optimizer(optimizer, 'optimizer.pt')
    >>>
@@ -58,21 +75,27 @@ class CheckpointIO(ABC):
                1. a file path, e.g. 'model.pt'
                2. a path to a json file which defines the index to the sharded checkpoint
                3. a path to a folder containing a unique .index.json file for sharded checkpoint
+               Distributed tensors cannot be loaded directly unless gathered offline via our CLI.
            strict (bool): whether to strictly enforce that the param names in
                the checkpoint match the keys returned by this module's state dict.
        """
+       # since we only support loading sharded and unsharded weight formats
+       # containing no distributed tensors, dtensor -> full tensor conversion
+       # should be done offline via our CLI
+       # the existence of an index file means it is a sharded checkpoint
        ckpt_path = Path(checkpoint)
-       is_sharded = self.is_sharded_checkpoint(ckpt_path)
+       index_file_exists, index_file_path = has_index_file(checkpoint)

+       # return the origin model instead of the unwrapped model
        origin_model = model

        if isinstance(model, ModelWrapper):
            model = model.unwrap()

-       if is_sharded:
-           self.load_sharded_model(model, ckpt_path, strict)
+       if index_file_exists:
+           self.load_sharded_model(model, index_file_path, strict)
        else:
-           self.load_unsharded_model(model, ckpt_path, strict)
+           self.load_unsharded_model(model, checkpoint, strict)

        return origin_model
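In short, sharded-vs-unsharded dispatch now keys off the presence of an index file instead of a dedicated flag. A usage sketch (paths hypothetical):

>>> ckpt_io = GeneralCheckpointIO()    # any concrete CheckpointIO
>>> # './checkpoints/' contains a unique model.index.json -> sharded load path
>>> model = ckpt_io.load_model(model, './checkpoints/')
>>> # 'model.bin' has no index file -> unsharded load path
>>> model = ckpt_io.load_model(model, 'model.bin')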
@@ -80,8 +103,10 @@ class CheckpointIO(ABC):
                   model: Union[nn.Module, ModelWrapper],
                   checkpoint: str,
                   shard: bool = False,
+                  gather_dtensor: bool = True,
                   prefix: str = None,
-                  size_per_shard: int = 1024):
+                  size_per_shard: int = 1024,
+                  use_safetensors: bool = False):
        """
        Save model to checkpoint.
@@ -103,17 +128,19 @@ class CheckpointIO(ABC):
            shard (bool): whether to shard the checkpoint. Default: False. If set to True, the checkpoint will be sharded into
                multiple files. The model shards will be specified by a `model.index.json` file. When shard = True, please ensure
                that the checkpoint path is a directory path instead of a file path.
+           gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True.
            prefix (str): prefix for the model checkpoint file name when shard=True. Default: None.
            size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard = True.
+           use_safetensors (bool): whether to use safetensors. Default: False. If set to True, the checkpoint will be saved in the safetensors format.
        """
        if isinstance(model, ModelWrapper):
            model = model.unwrap()

        if shard:
-           self.save_sharded_model(model, checkpoint, prefix, size_per_shard)
+           self.save_sharded_model(model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors)
        else:
-           self.save_unsharded_model(model, checkpoint)
+           self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
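The two new keyword arguments are threaded straight through to the abstract writers; the resulting call patterns look like this (file names hypothetical):

>>> # unsharded safetensors checkpoint; the file name must end with .safetensors (see utils.save_state_dict)
>>> ckpt_io.save_model(model, 'model.safetensors', use_safetensors=True)
>>> # sharded checkpoint where each rank keeps its own dtensor shards
>>> ckpt_io.save_model(model, './checkpoints/', shard=True, gather_dtensor=False, size_per_shard=512)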
    def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
        """
@@ -123,22 +150,27 @@ class CheckpointIO(ABC):
            optimizer (Optimizer): optimizer to be loaded.
            checkpoint (str): checkpoint path. This value accepts the same formats as the model checkpoint paths above.
        """
-       ckpt_path = Path(checkpoint)
-       is_sharded = self.is_sharded_checkpoint(ckpt_path)
+       index_file_exists, index_file_path = has_index_file(checkpoint)

-       if is_sharded:
-           self.load_sharded_optimizer(optimizer, ckpt_path)
+       if Path(checkpoint).is_dir() and not index_file_exists:
+           # if the checkpoint is a directory and there is no index file, raise error
+           raise ValueError(f'Cannot find index file in {checkpoint}')
+
+       if index_file_exists:
+           # the existence of an index file means it is a sharded checkpoint
+           self.load_sharded_optimizer(optimizer, index_file_path)
        else:
-           self.load_unsharded_optimizer(optimizer, ckpt_path)
+           self.load_unsharded_optimizer(optimizer, checkpoint)
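Passing a directory without an index file is now an explicit error rather than a silent fallback (paths hypothetical):

>>> ckpt_io.load_optimizer(optimizer, 'optimizer.pt')    # plain file -> unsharded load
>>> ckpt_io.load_optimizer(optimizer, './opt_ckpt/')     # directory without *.index.json -> ValueError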
    def save_optimizer(self,
                       optimizer: Optimizer,
                       checkpoint: str,
                       shard: bool = False,
+                      gather_dtensor: bool = True,
                       prefix: str = None,
                       size_per_shard: int = 1024):
        """
-       Save optimizer to checkpoint.
+       Save optimizer to checkpoint. Saving optimizer states is not compatible with safetensors.

        Args:
            optimizer (Optimizer): optimizer to be saved.
@@ -148,30 +180,33 @@ class CheckpointIO(ABC):
                3. a path to a folder containing a unique .index.json file for sharded checkpoint
            shard (bool): whether to shard the checkpoint. Default: False. If set to True, the checkpoint will be sharded into
                multiple files. The optimizer shards will be specified by an `optimizer.index.json` file.
+           gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True.
            prefix (str): prefix for the optimizer checkpoint when shard = True. Default: None.
            size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True.
        """
        if shard:
-           self.save_sharded_optimizer(optimizer, checkpoint, prefix, size_per_shard)
+           self.save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
        else:
-           self.save_unsharded_optimizer(optimizer, checkpoint)
+           self.save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)

    # ========================================================
    # Abstract methods for model loading/saving implementation
    # ========================================================
    @abstractmethod
-   def load_sharded_model(self, model: nn.Module, checkpoint: Path, strict: bool):
+   def load_sharded_model(self, model: nn.Module, index_file_path: str, strict: bool):
        """
        Load model from sharded checkpoint.

        Args:
            model (nn.Module): model to be loaded.
-           checkpoint (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
+           index_file_path (str): path to the .index.json file, or to a directory which contains a .index.json file.
+           strict (bool): whether to strictly enforce that the param names in
+               the checkpoint match the keys returned by this module's state dict.
        """
        pass
    @abstractmethod
-   def load_unsharded_model(self, model: nn.Module, checkpoint: Path, strict: bool):
+   def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
        """
        Load model from unsharded checkpoint.
@@ -184,26 +219,31 @@ class CheckpointIO(ABC):
        pass

    @abstractmethod
-   def save_sharded_model(self, model: nn.Module, checkpoint: Path, prefix: str, size_per_shard: int):
+   def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: str,
+                          size_per_shard: int, use_safetensors: bool):
        """
        Save model to sharded checkpoint.

        Args:
            model (nn.Module): model to be saved.
-           checkpoint (Path): checkpoint path. It should be a directory path.
+           checkpoint (str): checkpoint path. It should be a directory path.
+           gather_dtensor (bool): whether to gather the distributed tensor to the first device.
            prefix (str): prefix for the model checkpoint.
            size_per_shard (int): size per shard in MB.
+           use_safetensors (bool): whether to use safetensors.
        """
        pass

    @abstractmethod
-   def save_unsharded_model(self, model: nn.Module, checkpoint: Path):
+   def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
        """
        Save model to unsharded checkpoint.

        Args:
            model (nn.Module): model to be saved.
-           checkpoint (Path): checkpoint path. It should be a single file path pointing to a model weight binary.
+           checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
+           gather_dtensor (bool): whether to gather the distributed tensor to the first device.
+           use_safetensors (bool): whether to use safetensors.
        """
        pass
@@ -212,13 +252,13 @@ class CheckpointIO(ABC):
    # ========================================================

    @abstractmethod
-   def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
+   def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int):
        """
        Load optimizer from sharded checkpoint.

        Args:
            optimizer (Optimizer): optimizer to be loaded.
-           checkpoint (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
+           index_file_path (str): path to the .index.json file, or to a directory which contains a .index.json file.
            prefix (str): prefix for the optimizer checkpoint.
            size_per_shard (int): size per shard in MB.
        """
@@ -236,26 +276,29 @@ class CheckpointIO(ABC):
        pass

    @abstractmethod
-   def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
+   def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str,
+                              size_per_shard: int):
        """
        Save optimizer to sharded checkpoint.

        Args:
            optimizer (Optimizer): optimizer to be saved.
            checkpoint (Path): checkpoint path. It should be a directory path.
+           gather_dtensor (bool): whether to gather the distributed tensor to the first device.
            prefix (str): prefix for the optimizer checkpoint.
            size_per_shard (int): size per shard in MB.
        """
        pass

    @abstractmethod
-   def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
+   def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool):
        """
        Save optimizer to unsharded checkpoint.

        Args:
            optimizer (Optimizer): optimizer to be saved.
            checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
+           gather_dtensor (bool): whether to gather the distributed tensor to the first device.
        """
        pass
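Concrete subclasses must now accept the extra arguments even when they simply delegate. A minimal sketch of such an override (hypothetical class; GeneralCheckpointIO below provides defaults for the remaining hooks):

import torch.nn as nn

class VerboseCheckpointIO(GeneralCheckpointIO):

    def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
        # log the target path, then defer to the default torch/safetensors writer
        print(f'saving unsharded model to {checkpoint}')
        super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)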
@@ -264,7 +307,6 @@ class CheckpointIO(ABC):
    # as this is quite standard, there is no need
    # to make them abstract
    # ============================================
-
    def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
        """
        Save lr scheduler to checkpoint.
@@ -285,231 +327,3 @@ class CheckpointIO(ABC):
        """
        state_dict = torch.load(checkpoint)
        lr_scheduler.load_state_dict(state_dict)
-
-    # ========================================
-    # Helper functions for loading state dict
-    # ========================================
-    def get_sharded_checkpoint_index_file(self, checkpoint_path: Path):
-        """
-        Get the index file path for a sharded checkpoint.
-
-        Args:
-            checkpoint_path (Path): path to the checkpoint.
-
-        Returns:
-            Path: path to the index file.
-        """
-        if checkpoint_path.is_file():
-            # check if it is .index.json
-            if checkpoint_path.name.endswith('.index.json'):
-                return checkpoint_path
-            else:
-                raise ValueError(f'Invalid checkpoint path: {checkpoint_path}. ')
-        elif checkpoint_path.is_dir():
-            # check if there is only one file ending with .index.json in this directory
-            index_files = list(checkpoint_path.glob('*.index.json'))
-            if len(index_files) == 1:
-                return index_files[0]
-            else:
-                raise ValueError(f'Found {len(index_files)} index files in {checkpoint_path}. ')
-
-    def is_sharded_checkpoint(self, checkpoint_path: Path):
-        """
-        Check whether the checkpoint is sharded.
-
-        Args:
-            checkpoint (str): checkpoint path.
-
-        Returns:
-            bool: whether the checkpoint is sharded.
-        """
-        if checkpoint_path.is_file():
-            # check if it is .index.json
-            if checkpoint_path.name.endswith('.index.json'):
-                return True
-            else:
-                return False
-        elif checkpoint_path.is_dir():
-            # check if there is only one file ending with .index.json in this directory
-            index_files = list(checkpoint_path.glob('*.index.json'))
-            if len(index_files) == 1:
-                return True
-            else:
-                raise ValueError(f'Found {len(index_files)} index files in {checkpoint_path}. ')
-
-    def get_checkpoint_shard_filenames(self, index_file_path: Path):
-        """
-        Get checkpoint shard filenames from a json file.
-
-        Args:
-            index_file_path (Path): path to the json file.
-
-        Returns:
-            list: checkpoint shard filenames.
-        """
-        with open(str(index_file_path), 'r') as f:
-            shard_filenames = json.load(f)
-
-        if "weight_map" in index:
-            index = index["weight_map"]
-
-        checkpoint_root_path = index_file_path.absolute().parent
-
-        # read the checkpoint file list from the json file and get a list of unique file names
-        checkpoint_files = sorted(list(set(index.values())))
-
-        # get the absolute paths for all checkpoint files
-        checkpoint_files = [checkpoint_root_path.joinpath(f) for f in checkpoint_files]
-        return shard_filenames
-
-    def load_safetensors_state_dict(self, *args, **kwargs):
-        """
-        Load safetensors state dict from checkpoint.
-        """
-        # TODO(FrankLeeeee): support huggingface safetensors
-        raise NotImplementedError("This method is not implemented to support safe tensors")
-
-    def load_state_dict(self, checkpoint_file_path: Path):
-        """
-        Load state dict from checkpoint.
-
-        Args:
-            checkpoint_file_path (Path): path to the checkpoint file.
-
-        Returns:
-            dict: state dict.
-        """
-        return torch.load(str(checkpoint_file_path))
-
-    # ======================================
-    # Helper functions for saving state dict
-    # ======================================
-    def save_safetensors_state_dict(self, *args, **kwargs):
-        """
-        Save safetensors state dict to checkpoint.
-        """
-        # TODO(FrankLeeeee): support huggingface safetensors
-        raise NotImplementedError("This method is not implemented to support safe tensors")
-
-    def generate_checkpoint_shard_file_name(self, index: int, total_number: int, prefix: str = None):
-        """
-        Generate checkpoint shard file name.
-
-        Args:
-            index (int): index of the shard.
-            total_number (int): total number of shards.
-            prefix (str): prefix of the shard file name. Default: None.
-        """
-        if prefix is None:
-            return f"{index}-of-{total_number}.bin"
-        else:
-            return f"{prefix}-{index}-of-{total_number}.bin"
-
-    def save_checkpoint(self, state_dict: dict, checkpoint_file_path: Path):
-        """
-        Save state dict to checkpoint.
-
-        Args:
-            state_dict (dict): state dict.
-            checkpoint_file_path (Path): path to the checkpoint file.
-        """
-        torch.save(state_dict, str(checkpoint_file_path))
-
-    def save_state_dict_as_shard(self, state_dict: dict, index: int, total_number: int, prefix: str,
-                                 checkpoint_path: Path):
-        """
-        Save state dict as shard.
-
-        Args:
-            state_dict (dict): state dict.
-            checkpoint_path (Path): path to the checkpoint file.
-        """
-        # generate the shard name
-        shard_file_name = self.generate_checkpoint_shard_file_name(index, total_number, prefix)
-        shard_file_path = checkpoint_path.joinpath(shard_file_name)
-
-        # save the shard
-        self.save_checkpoint(state_dict, shard_file_path)
-
-    def calculate_param_size(self, param: torch.Tensor):
-        """
-        Calculate the size of a parameter in MB. Used to compute whether a group of params exceed the shard size.
-        If so, a new shard should be created.
-
-        Args:
-            param (torch.Tensor): parameter tensor.
-        """
-        # TODO(FrankLeeeee): check if this tensor is a DTensor, compute its global size if so
-        return param.numel() * param.element_size() / 1024 / 1024
-
-
-class ShardCheckpointIndexFile:
-    """
-    This class is a data structure to keep the content in the index.json file for sharded checkpoint.
-
-    Example:
-        >>> index = ShardCheckpointIndexFile()
-        >>> index.load('index.json')
-        >>> index.append_meta_data('model_type', 'bert')
-        >>> index.append_weight_map('bert.embeddings.word_embeddings.weight', 'bert.embeddings.word_embeddings.weight-0-of-2.bin')
-        >>> index.export('index.json')
-    """
-
-    def __init__(self) -> None:
-        self.metadata: dict = dict()
-        self.weight_map: dict = dict()
-
-    def load(self, json_path: str):
-        """
-        Load the index file from a json file.
-
-        Args:
-            json_path (str): path to the json file.
-        """
-        # load the json file
-        with open(json_path, 'r') as f:
-            index = json.load(f)
-
-        # assign attributes if exists
-        if "metadata" in index:
-            self.metadata = index["metadata"]
-        if "weight_map" in index:
-            self.weight_map = index["weight_map"]
-
-    def export(self, json_path: str):
-        """
-        Export the index file to a json file.
-
-        Args:
-            json_path (str): path to the json file.
-        """
-        # create the index file
-        index = dict()
-        index["metadata"] = self.metadata
-        index["weight_map"] = self.weight_map
-
-        # export the index file
-        with open(json_path, 'w') as f:
-            json.dump(index, f, indent=4)
-
-    def append_weight_map(self, param_name: str, shard_file: str):
-        """
-        Append a weight map entry to the index file.
-
-        Args:
-            param_name (str): name of the parameter.
-            shard_file (str): name of the shard file.
-        """
-        self.weight_map[param_name] = shard_file
-
-    def append_meta_data(self, name: str, val: Any):
-        """
-        Append a metadata entry to the index file.
-
-        Args:
-            name (str): name of the metadata.
-            val (Any): value of the metadata.
-        """
-        self.metadata[name] = val


@@ -4,42 +4,67 @@ import torch.nn as nn
from torch.optim import Optimizer

from .checkpoint_io_base import CheckpointIO
+from .index_file import CheckpointIndexFile
+from .utils import has_index_file, load_state_dict, save_state_dict

__all__ = ['GeneralCheckpointIO']


class GeneralCheckpointIO(CheckpointIO):

-    def load_sharded_model(self, model: nn.Module, checkpoint: Path, strict: bool):
-        index_file_path = self.get_sharded_checkpoint_index_file(checkpoint)
+    def load_sharded_model(self, model: nn.Module, index_file_path: Path, strict: bool):
+        # load the index file
+        index_file = CheckpointIndexFile.from_file(index_file_path)

        # iterate over the shard checkpoint files
        # and load each
-        shard_files = self.get_checkpoint_shard_filenames(index_file_path)
-        for shard_file in shard_files:
-            shard_checkpoint = self.load_state_dict(shard_file)
+        index_file.assert_no_dtensor_checkpoint()
+        checkpoint_file_list, _ = index_file.get_checkpoint_filenames()
+        for shard_file in checkpoint_file_list:
+            shard_checkpoint = load_state_dict(shard_file)
            model.load_state_dict(shard_checkpoint, strict=strict)

-    def load_unsharded_model(self, model: nn.Module, checkpoint: Path, strict: bool):
-        checkpoint = self.load_state_dict(str(checkpoint))
+    def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
+        checkpoint = load_state_dict(checkpoint)
        model.load_state_dict(checkpoint, strict=strict)

-    def save_sharded_model(self, model: nn.Module, checkpoint: Path, prefix: str, size_per_shard: int):
+    def save_sharded_model(self, model: nn.Module, checkpoint: Path, gather_dtensor: bool, prefix: str,
+                           size_per_shard: int, use_safetensors: bool):
        # TODO(FrankLeeeee): implement this method as it can be supported by Huggingface model
        raise NotImplementedError("Sharded model checkpoint is not supported yet.")

-    def save_unsharded_model(self, model: nn.Module, checkpoint: Path):
-        self.save_checkpoint(model.state_dict(), checkpoint)
+    def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+        state_dict = model.state_dict()
+
+        # TODO(FrankLeeeee): add support for gather_dtensor
+        if gather_dtensor:
+            pass
+
+        # save the checkpoint
+        save_state_dict(state_dict, checkpoint, use_safetensors)

    def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
        raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")

    def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
-        checkpoint = self.load_state_dict(checkpoint)
+        checkpoint = load_state_dict(checkpoint)
        optimizer.load_state_dict(checkpoint)

-    def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
+    def save_sharded_optimizer(
+        self,
+        optimizer: Optimizer,
+        checkpoint: Path,
+        gather_dtensor: bool,
+        prefix: str,
+        size_per_shard: int,
+    ):
        raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")

-    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
-        self.save_checkpoint(optimizer.state_dict(), checkpoint)
+    def save_unsharded_optimizer(
+        self,
+        optimizer: Optimizer,
+        checkpoint: Path,
+        gather_dtensor: bool,
+    ):
+        # TODO(FrankLeeeee): handle distributed tensors
+        save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False)
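A round trip through the new safetensors path, mirroring the unit test at the end of this diff (temp-file handling elided, names hypothetical):

>>> ckpt_io = GeneralCheckpointIO()
>>> ckpt_io.save_model(model, 'model.safetensors', use_safetensors=True)
>>> model = ckpt_io.load_model(model, 'model.safetensors')
>>> ckpt_io.save_optimizer(optimizer, 'optimizer.pt')    # optimizer states always go through torch.save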


@@ -0,0 +1,150 @@
import json
from pathlib import Path
from typing import Any, List, Tuple, Union

from .utils import is_dtensor_checkpoint

__all__ = ['CheckpointIndexFile']


class CheckpointIndexFile:
    """
    This class is a data structure to keep the content in the index.json file for sharded checkpoint.

    Example:
        >>> index = CheckpointIndexFile.from_file('model.index.json')
        >>> index.append_meta_data('model_type', 'bert')
        >>> index.append_weight_map('bert.embeddings.word_embeddings.weight', 'model_0001-of-0002.bin')
        >>> index.export('new_index.json')
    """

    def __init__(self) -> None:
        self.root_path = None
        self.metadata: dict = dict()
        self.weight_map: dict = dict()

    @staticmethod
    def from_file(index_path: Union[str, Path]):
        """
        Create a CheckpointIndexFile object from a json file.

        Args:
            index_path (str): path to the json file.

        Returns:
            CheckpointIndexFile: CheckpointIndexFile object.
        """
        index = CheckpointIndexFile()
        index.load(index_path)
        return index

    def load(self, json_path: str):
        """
        Load the index file from a json file.

        Args:
            json_path (str): path to the json file.
        """
        # load the json file
        with open(json_path, 'r') as f:
            index = json.load(f)

        # assign attributes if exists
        if "metadata" in index:
            self.metadata = index["metadata"]
        if "weight_map" in index:
            self.weight_map = index["weight_map"]

        # assign the root directory for the index file
        self.root_path = Path(json_path).absolute().parent

    def export(self, json_path: str):
        """
        Export the index file to a json file.

        Args:
            json_path (str): path to the json file.
        """
        # create the index file
        index = dict()
        index["metadata"] = self.metadata
        index["weight_map"] = self.weight_map

        # export the index file
        with open(json_path, 'w') as f:
            json.dump(index, f, indent=4)

    def append_weight_map(self, param_name: str, shard_file: str):
        """
        Append a weight map entry to the index file.

        Args:
            param_name (str): name of the parameter.
            shard_file (str): name of the shard file.
        """
        self.weight_map[param_name] = shard_file

    def append_meta_data(self, name: str, val: Any):
        """
        Append a metadata entry to the index file.

        Args:
            name (str): name of the metadata.
            val (Any): value of the metadata.
        """
        self.metadata[name] = val

    def contains_dtensor(self):
        """
        Check if the index file contains any distributed tensor. The distributed tensors will be stored as
        `dtensor/module.linear.weight.*.bin` or `dtensor/module.linear.weight.*.safetensors` in the weight map.

        Returns:
            bool: True if the index file contains any distributed tensor, False otherwise.
        """
        for value in self.weight_map.values():
            if value.endswith(".*.bin") or value.endswith(".*.safetensors"):
                return True
        return False

    def get_checkpoint_filenames(self) -> Tuple[List[str], List[str]]:
        """
        Get the sorted list of regular checkpoint filenames and the list of dtensor checkpoint
        filenames in the weight map.

        Returns:
            Tuple[List[str], List[str]]: (regular checkpoint filenames, dtensor checkpoint filenames).
        """
        # read the checkpoint file list from the json file and get a list of unique file names
        checkpoint_files = sorted(list(set(self.weight_map.values())))

        # get the absolute paths for all checkpoint files
        checkpoint_files = [str(self.root_path.joinpath(f)) for f in checkpoint_files]

        dtensor_list = []
        checkpoint_list = []

        for ckpt_file in checkpoint_files:
            if is_dtensor_checkpoint(ckpt_file):
                dtensor_list.append(ckpt_file)
            else:
                checkpoint_list.append(ckpt_file)

        return checkpoint_list, dtensor_list

    def assert_no_dtensor_checkpoint(self):
        for val in self.weight_map.values():
            if is_dtensor_checkpoint(val):
                raise ValueError(f"Checkpoint file {val} contains distributed tensor")

    def get_checkpoint_file(self, param_name: str) -> str:
        """
        Get the checkpoint file name for a parameter.

        Args:
            param_name (str): name of the parameter.

        Returns:
            str: checkpoint file name.
        """
        ckpt_path = self.weight_map[param_name]
        return ckpt_path
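A sketch of the index-file life cycle and the JSON it produces (file and parameter names hypothetical):

>>> index = CheckpointIndexFile.from_file('./checkpoints/model.index.json')
>>> index.append_meta_data('total_size', 1024)
>>> index.append_weight_map('fc.weight', 'model-00001-of-00002.bin')
>>> index.export('./checkpoints/model.index.json')
>>> # exported JSON: {"metadata": {"total_size": 1024, ...}, "weight_map": {"fc.weight": "model-00001-of-00002.bin", ...}}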


@@ -0,0 +1,278 @@
from pathlib import Path
from typing import List, Optional, Tuple

import torch

# ======================================
# General helper functions
# ======================================


def calculate_tensor_size(tensor: torch.Tensor) -> float:
    """
    Calculate the size of a parameter in MB. Used to compute whether a group of params exceed the shard size.
    If so, a new shard should be created.

    Args:
        tensor (torch.Tensor): the tensor to calculate size for.

    Returns:
        float: size of the tensor in MB.
    """
    return tensor.numel() * tensor.element_size() / 1024 / 1024
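For instance, a 1024x1024 fp32 tensor occupies 1024 * 1024 * 4 bytes, which is exactly 4 MB under this formula:

>>> calculate_tensor_size(torch.zeros(1024, 1024, dtype=torch.float32))
4.0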
def is_safetensors_available() -> bool:
    """
    Check whether safetensors is available.

    Returns:
        bool: whether safetensors is available.
    """
    try:
        import safetensors
        return True
    except ImportError:
        return False


def is_dtensor_checkpoint(checkpoint_file_path: str) -> bool:
    """
    Check whether the checkpoint file is a dtensor checkpoint.

    Args:
        checkpoint_file_path (str): path to the checkpoint file.

    Returns:
        bool: whether the checkpoint file is a dtensor checkpoint.
    """
    if checkpoint_file_path.endswith('.*.safetensors') or checkpoint_file_path.endswith('.*.bin'):
        return True
    else:
        return False


def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
    """
    Check whether the checkpoint file is a safetensor checkpoint.

    Args:
        checkpoint_file_path (str): path to the checkpoint file.

    Returns:
        bool: whether the checkpoint file is a safetensor checkpoint.
    """
    if checkpoint_file_path.endswith('.safetensors'):
        return True
    else:
        return False


# ======================================
# Helper functions for saving state dict
# ======================================


def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) -> None:
    """
    Save state dict to checkpoint.

    Args:
        state_dict (dict): state dict.
        checkpoint_file_path (str): path to the checkpoint file.
        use_safetensors (bool): whether to use safetensors to save the checkpoint.
    """
    if use_safetensors:
        assert is_safetensors_available(), "safetensors is not available."
        assert checkpoint_file_path.endswith('.safetensors'), \
            "safetensors only supports .safetensors suffix for checkpoint file."
        from safetensors.torch import save_file
        save_file(state_dict, checkpoint_file_path)
    else:
        torch.save(state_dict, checkpoint_file_path)
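Both branches in one sketch (file names hypothetical; the safetensors branch requires the package to be installed):

>>> sd = {'weight': torch.ones(2, 2)}
>>> save_state_dict(sd, 'tiny.safetensors', use_safetensors=True)
>>> save_state_dict(sd, 'tiny.bin', use_safetensors=False)    # plain torch.save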
def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFile", use_safetensors: bool) -> None:
    """
    Save a distributed tensor to checkpoint. This checkpoint will be a dictionary which contains
    only one tensor.

    Args:
        name (str): name of the distributed parameter.
        tensor (Tensor): tensor to be saved.
        index_file (CheckpointIndexFile): index file of the checkpoint.
        use_safetensors (bool): whether to use safetensors to save the checkpoint.
    """
    root_path = index_file.root_path
    output_root_path = root_path.joinpath('dtensor')

    # create directory
    output_root_path.mkdir(exist_ok=True)

    # save tensor to this directory
    # TODO(YuliangLiu): get index of the tensor shard
    # e.g. index =
    index = 0

    # save tensor to file
    ckpt_file_name = generate_dtensor_file_name(name, index, use_safetensors)
    ckpt_file_path = output_root_path.joinpath(ckpt_file_name)

    # dtensor ckpt file always contains only one tensor
    state_dict = {name: tensor}
    save_state_dict(state_dict, str(ckpt_file_path), use_safetensors)

    # update the weight map
    # * means all shards
    ckpt_file_name_in_weight_map = 'dtensor/' + generate_dtensor_file_name(name, '*', use_safetensors)
    index_file.append_weight_map(name, ckpt_file_name_in_weight_map)
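So for a hypothetical parameter `module.linear.weight` saved with safetensors, the rank-local file and the shared weight-map entry look like:

>>> # on-disk file for this rank:     dtensor/module.linear.weight.0.safetensors
>>> # weight-map entry (all shards):  {'module.linear.weight': 'dtensor/module.linear.weight.*.safetensors'}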
def get_checkpoint_file_suffix(use_safetensors: bool) -> str:
    """
    Get checkpoint file suffix.

    Args:
        use_safetensors (bool): whether to use safetensors to save the checkpoint.

    Returns:
        str: checkpoint file suffix, including the leading dot.
    """
    if use_safetensors:
        return '.safetensors'
    else:
        return '.bin'


def generate_checkpoint_shard_file_name(index: int,
                                        total_number: int,
                                        use_safetensors: bool,
                                        prefix: str = None) -> str:
    """
    Generate checkpoint shard file name.

    Args:
        index (int): index of the shard.
        total_number (int): total number of shards.
        use_safetensors (bool): whether to use safetensors to save the checkpoint.
        prefix (str): prefix of the shard file name. Default: None.

    Returns:
        str: checkpoint shard file name.
    """
    # the suffix already carries the leading dot, so it is appended directly
    suffix = get_checkpoint_file_suffix(use_safetensors)

    if prefix is None:
        return f"{index:05d}-of-{total_number:05d}{suffix}"
    else:
        return f"{prefix}-{index:05d}-of-{total_number:05d}{suffix}"
def generate_dtensor_file_name(param_name: str, index: int, use_safetensors: bool) -> str:
    """
    Generate dtensor file name.

    Args:
        param_name (str): name of the distributed parameter.
        index (int): index of the shard.
        use_safetensors (bool): whether to use safetensors to save the checkpoint.

    Returns:
        str: dtensor file name.
    """
    suffix = get_checkpoint_file_suffix(use_safetensors)
    return f'{param_name}.{index}{suffix}'


def save_state_dict_as_shard(
    state_dict: dict,
    checkpoint_path: str,
    index: int,
    total_number: int,
    use_safetensors: bool,
    prefix: str = None,
) -> None:
    """
    Save state dict as shard.

    Args:
        state_dict (dict): state dict.
        checkpoint_path (str): path to the checkpoint file.
        index (int): index of the shard.
        total_number (int): total number of shards.
        prefix (str): prefix of the shard file name.
        use_safetensors (bool): whether to use safetensors to save the checkpoint.
    """
    # generate the shard name
    shard_file_name = generate_checkpoint_shard_file_name(index, total_number, use_safetensors, prefix)
    shard_file_path = Path(checkpoint_path).joinpath(shard_file_name).absolute()

    # save the shard
    save_state_dict(state_dict, str(shard_file_path), use_safetensors)
# ========================================
# Helper functions for loading state dict
# ========================================


def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]:
    """
    Check whether the checkpoint has an index file.

    Args:
        checkpoint_path (str): path to the checkpoint.

    Returns:
        Tuple[bool, Optional[Path]]: a tuple of (has_index_file, index_file_path)
    """
    checkpoint_path = Path(checkpoint_path)
    if checkpoint_path.is_file():
        # check if it is .index.json
        if checkpoint_path.name.endswith('.index.json'):
            return True, checkpoint_path
        else:
            return False, None
    elif checkpoint_path.is_dir():
        # check if there is only one file ending with .index.json in this directory
        index_files = list(checkpoint_path.glob('*.index.json'))

        # if we found a .index.json file, make sure there is only one
        if len(index_files) > 0:
            assert len(
                index_files
            ) == 1, f'Expected to find one .index.json file in {checkpoint_path}, but found {len(index_files)}'

        if len(index_files) == 1:
            return True, index_files[0]
        else:
            return False, None
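Both return shapes in one sketch (paths hypothetical):

>>> has_index_file('./checkpoints/')    # directory with a single model.index.json inside
(True, PosixPath('checkpoints/model.index.json'))
>>> has_index_file('model.bin')
(False, None)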
def load_state_dict(checkpoint_file_path: str):
    """
    Load state dict from checkpoint.

    Args:
        checkpoint_file_path (str): path to the checkpoint file.

    Returns:
        dict: state dict.
    """

    assert not is_dtensor_checkpoint(checkpoint_file_path), \
        f'Cannot load state dict from dtensor checkpoint {checkpoint_file_path}, you should convert the distributed tensors to gathered tensors with our CLI offline.'

    if is_safetensor_checkpoint(checkpoint_file_path):
        assert is_safetensors_available(), \
            f'Cannot load state dict from safetensor checkpoint {checkpoint_file_path}, because safetensors is not available. Please install safetensors first with pip install safetensors.'
        # load with safetensors
        from safetensors import safe_open
        state_dict = {}
        with safe_open(checkpoint_file_path, framework="pt", device="cpu") as f:
            for k in f.keys():
                state_dict[k] = f.get_tensor(k)
        return state_dict
    else:
        # load with torch
        return torch.load(checkpoint_file_path)


@@ -9,3 +9,4 @@ fabric
contexttimer
ninja
torch>=1.11
+safetensors


@@ -71,6 +71,29 @@ def check_dataloader_sharding():
        batch_to_compare), 'Same number was found across ranks but expected it to be different'


+def check_checkpoint_save_and_load():
+    model_fn, data_gen_fn, output_transform_fn, _ = model_zoo['timm_resnet']
+
+    plugin = TorchDDPPlugin()
+    booster = Booster(plugin=plugin)
+
+    model = model_fn()
+    optimizer = SGD(model.parameters(), lr=1e-3)
+    criterion = lambda x: x.mean()
+    data = data_gen_fn()
+
+    data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()}
+
+    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
+
+    output = model(**data)
+    output = output_transform_fn(output)
+    output_key = list(output.keys())[0]
+    loss = criterion(output[output_key])
+
+    booster.backward(loss, optimizer)
+
+
def run_dist(rank, world_size, port):
    # init dist env
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')


@@ -1,5 +1,6 @@
import tempfile

+import pytest
import torch
from torch.optim import Adam
from torchvision.models import resnet18
@@ -14,7 +15,8 @@ from colossalai.checkpoint_io import GeneralCheckpointIO
# ========


-def test_unsharded_checkpoint():
+@pytest.mark.parametrize('use_safetensors', [True, False])
+def test_unsharded_checkpoint(use_safetensors: bool):
    # create a model and optimizer
    model = resnet18()
    optimizer = Adam(model.parameters(), lr=0.001)
@@ -29,12 +31,16 @@ def test_unsharded_checkpoint():
        optimizer.step()

    # create a temp file for checkpoint
-    model_ckpt_tempfile = tempfile.NamedTemporaryFile()
+    if use_safetensors:
+        suffix = ".safetensors"
+    else:
+        suffix = ".bin"
+    model_ckpt_tempfile = tempfile.NamedTemporaryFile(suffix=suffix)
    optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile()

    # save the model and optimizer
    ckpt_io = GeneralCheckpointIO()
-    ckpt_io.save_model(model, model_ckpt_tempfile.name)
+    ckpt_io.save_model(model, model_ckpt_tempfile.name, use_safetensors=use_safetensors)
    ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name)

    # create new model
@@ -68,3 +74,4 @@ def test_unsharded_checkpoint():
    # check for model and optimizer state dict recursively
    recursive_check(model.state_dict(), new_model.state_dict())
    recursive_check(optimizer.state_dict(), new_optimizer.state_dict())
+    recursive_check(optimizer.state_dict(), new_optimizer.state_dict())