Mirror of https://github.com/hpcaitech/ColossalAI

[checkpoint] refactored the API and added safetensors support (#3427)

* [checkpoint] refactored the API and added safetensors support
* polish code

Branch: pull/3442/head
Parent: 26b7aac0be
Commit: 1beb85cc25
@@ -33,7 +33,7 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
         return super().load_unsharded_model(model, checkpoint, strict=strict)
 
-    def save_unsharded_model(self, model: nn.Module, checkpoint: str):
+    def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool):
         """
         Save model to checkpoint but only on master process.
         """
@@ -41,7 +41,7 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         if self.coordinator.is_master():
             super().save_unsharded_model(model, checkpoint)
 
-    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str):
+    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
         """
         Save optimizer to checkpoint but only on master process.
         """
@@ -1,4 +1,5 @@
-from .checkpoint_io_base import CheckpointIO, ShardCheckpointIndexFile
+from .checkpoint_io_base import CheckpointIO
 from .general_checkpoint_io import GeneralCheckpointIO
+from .index_file import CheckpointIndexFile
 
-__all__ = ['CheckpointIO', 'ShardCheckpointIndexFile', 'GeneralCheckpointIO']
+__all__ = ['CheckpointIO', 'CheckpointIndexFile', 'GeneralCheckpointIO']
@@ -1,7 +1,6 @@
-import json
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Any, Union
+from typing import Union
 
 import torch
 import torch.nn as nn
@@ -10,7 +9,9 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 
 from colossalai.interface import ModelWrapper
 
-__all__ = ['CheckpointIO', 'ShardCheckpointIndexFile']
+from .utils import has_index_file
+
+__all__ = ['CheckpointIO']
 
 
 class CheckpointIO(ABC):
@@ -25,15 +26,31 @@ class CheckpointIO(ABC):
     >>> # load model from checkpoint
     >>> model = checkpoint_io.load_model(model, 'model.pt')
     >>>
-    >>> # save model to checkpoint
+    >>> # save model to checkpoint, any distributed tensor is gathered by default
     >>> checkpoint_io.save_model(model, 'model.pt')
     >>>
+    >>> # if the model contains distributed tensor, and you don't want to gather it
+    >>> # each rank will save its own shard of the distributed tensor
+    >>> checkpoint_io.save_model(model, 'model.pt', gather_dtensor=False)
+    >>>
     >>> # save model to sharded checkpoints
     >>> checkpoint_io.save_model(model, './checkpoints/', shard=True)
     >>>
+    >>> # save model to sharded and assume we don't want to gather distributed tensors
+    >>> checkpoint_io.save_model(model, './checkpoints/', shard=True, gather_dtensor=False)
+    >>>
+    >>> # Note:
+    >>> # 1. we don't support loading from distributed tensors, conversion from distributed tensors
+    >>> #    checkpoints to full tensor checkpoint should be done offline via our CLI
+    >>> # 2. you don't have to specify whether the model is sharded or not when loading the model
+    >>> #    as it will be automatically detected
+    >>>
     >>> # load model from sharded checkpoints
     >>> model = checkpoint_io.load_model(model, './checkpoints/')
     >>>
+    >>> # load model from unsharded checkpoints
+    >>> model = checkpoint_io.load_model(model, './checkpoints/')
+    >>>
     >>> # load optimizer from checkpoint
     >>> optimizer = checkpoint_io.load_optimizer(optimizer, 'optimizer.pt')
     >>>
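
As a concrete complement to the doctest above, here is a minimal, self-contained sketch of the refactored flow, modeled on the unit test added later in this commit. The tiny nn.Linear model, the temporary file names, and the assumption that safetensors is installed are illustrative only and not part of the commit.

import tempfile

import torch.nn as nn
from torch.optim import Adam

from colossalai.checkpoint_io import GeneralCheckpointIO

# toy model and optimizer, purely for illustration
model = nn.Linear(8, 8)
optimizer = Adam(model.parameters(), lr=1e-3)

ckpt_io = GeneralCheckpointIO()

# unsharded save; with use_safetensors=True the path must end in .safetensors
model_file = tempfile.NamedTemporaryFile(suffix='.safetensors')
optim_file = tempfile.NamedTemporaryFile(suffix='.bin')
ckpt_io.save_model(model, model_file.name, use_safetensors=True)
ckpt_io.save_optimizer(optimizer, optim_file.name)

# loading auto-detects whether the checkpoint is sharded or not
new_model = ckpt_io.load_model(nn.Linear(8, 8), model_file.name)
new_optimizer = Adam(new_model.parameters(), lr=1e-3)
ckpt_io.load_optimizer(new_optimizer, optim_file.name)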
@@ -58,21 +75,27 @@ class CheckpointIO(ABC):
                 1. a file path, e.g. 'model.pt'
                 2. a path to a json file which defines the index to the sharded checkpoint
                 3. a path to a folder containing a unique .index.json file for sharded checkpoint
+                Distributed tensors cannot be loaded directly unless gathered offline via our CLI.
             strict (bool): whether to strictly enforce that the param name in
                 the checkpoint match the keys returned by this module's.
         """
+        # since we only support loaded sharded and unsharded weight format
+        # containing no distributed tensors, dtensor -> full tensor conversion
+        # should be done offline via our CLI
+        # the existence of index file means it is a sharded checkpoint
         ckpt_path = Path(checkpoint)
-        is_sharded = self.is_sharded_checkpoint(ckpt_path)
+        index_file_exists, index_file_path = has_index_file(checkpoint)
 
+        # return the origin model instead of the unwrapped model
        origin_model = model
 
         if isinstance(model, ModelWrapper):
             model = model.unwrap()
 
-        if is_sharded:
-            self.load_sharded_model(model, ckpt_path, strict)
+        if index_file_exists:
+            self.load_sharded_model(model, index_file_path, strict)
         else:
-            self.load_unsharded_model(model, ckpt_path, strict)
+            self.load_unsharded_model(model, checkpoint, strict)
 
         return origin_model
 
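
The sketch below shows the detection rule that load_model now relies on: the presence of a *.index.json file marks a sharded checkpoint. The directory names are hypothetical, and it assumes the helper lives in colossalai.checkpoint_io.utils, as the relative import above suggests.

import json
import tempfile
from pathlib import Path

from colossalai.checkpoint_io.utils import has_index_file

ckpt_dir = Path(tempfile.mkdtemp())

# no index file in the directory -> treated as an unsharded checkpoint
exists, index_path = has_index_file(str(ckpt_dir))
assert exists is False and index_path is None

# dropping a *.index.json file into the directory flips the decision
(ckpt_dir / 'model.index.json').write_text(json.dumps({'metadata': {}, 'weight_map': {}}))
exists, index_path = has_index_file(str(ckpt_dir))
assert exists is True and index_path.name == 'model.index.json'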
@@ -80,8 +103,10 @@ class CheckpointIO(ABC):
                    model: Union[nn.Module, ModelWrapper],
                    checkpoint: str,
                    shard: bool = False,
+                   gather_dtensor: bool = True,
                    prefix: str = None,
-                   size_per_shard: int = 1024):
+                   size_per_shard: int = 1024,
+                   use_safetensors: bool = False):
         """
         Save model to checkpoint.
 
@@ -103,17 +128,19 @@ class CheckpointIO(ABC):
             shard (bool): whether to shard the checkpoint. Default: False. If set to True, the checkpoint will be sharded into
                 multiple files. The model shards will be specified by a `model.index.json` file. When shard = True, please ensure
                 that the checkpoint path is a directory path instead of a file path.
+            gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True.
             prefix (str): prefix for the model checkpoint file name when shard=True. Default: None.
             size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard = True.
+            use_safetensors (bool): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved
         """
 
         if isinstance(model, ModelWrapper):
             model = model.unwrap()
 
         if shard:
-            self.save_sharded_model(model, checkpoint, prefix, size_per_shard)
+            self.save_sharded_model(model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors)
         else:
-            self.save_unsharded_model(model, checkpoint)
+            self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
 
     def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
         """
@@ -123,22 +150,27 @@ class CheckpointIO(ABC):
             optimizer (Optimizer): optimizer to be loaded.
             checkpoint (str): checkpoint path. This value is made compatible with the model checkpoints in the
         """
-        ckpt_path = Path(checkpoint)
-        is_sharded = self.is_sharded_checkpoint(ckpt_path)
+        index_file_exists, index_file_path = has_index_file(checkpoint)
 
-        if is_sharded:
-            self.load_sharded_optimizer(optimizer, ckpt_path)
+        if Path(checkpoint).is_dir() and not index_file_exists:
+            # if the checkpoint is a directory and there is no index file, raise error
+            raise ValueError(f'Cannot find index file in {checkpoint}')
+
+        if index_file_exists:
+            # the existence of index file means it is a sharded checkpoint
+            self.load_sharded_optimizer(optimizer, index_file_path)
         else:
-            self.load_unsharded_optimizer(optimizer, ckpt_path)
+            self.load_unsharded_optimizer(optimizer, checkpoint)
 
     def save_optimizer(self,
                        optimizer: Optimizer,
                        checkpoint: str,
                        shard: bool = False,
+                       gather_dtensor=True,
                        prefix: str = None,
                        size_per_shard: int = 1024):
         """
-        Save optimizer to checkpoint.
+        Save optimizer to checkpoint. Optimizer states saving is not compatible with safetensors.
 
         Args:
             optimizer (Optimizer): optimizer to be saved.
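
One behavioural consequence of the check above, shown as a small hedged example (the empty directory and toy optimizer are illustrative): pointing load_optimizer at a directory that contains no index file now fails fast with a ValueError instead of being treated as a single-file checkpoint.

import tempfile

import torch.nn as nn
from torch.optim import Adam

from colossalai.checkpoint_io import GeneralCheckpointIO

ckpt_io = GeneralCheckpointIO()
optimizer = Adam(nn.Linear(4, 4).parameters(), lr=1e-3)

empty_dir = tempfile.mkdtemp()
try:
    ckpt_io.load_optimizer(optimizer, empty_dir)
except ValueError as err:
    print(err)    # Cannot find index file in /tmp/...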
@ -148,30 +180,33 @@ class CheckpointIO(ABC):
|
||||||
3. a path to a folder containing a unique .index.json file for sharded checkpoint
|
3. a path to a folder containing a unique .index.json file for sharded checkpoint
|
||||||
shard (bool): whether to shard the checkpoint. Default: False. If set to True, the checkpoint will be sharded into
|
shard (bool): whether to shard the checkpoint. Default: False. If set to True, the checkpoint will be sharded into
|
||||||
multiple files. The optimizer shards will be specified by an `optimizer.index.json` file.
|
multiple files. The optimizer shards will be specified by an `optimizer.index.json` file.
|
||||||
|
gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True.
|
||||||
prefix (str): prefix for the optimizer checkpoint when shard = True. Default: None.
|
prefix (str): prefix for the optimizer checkpoint when shard = True. Default: None.
|
||||||
size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True.
|
size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True.
|
||||||
"""
|
"""
|
||||||
if shard:
|
if shard:
|
||||||
self.save_sharded_optimizer(optimizer, checkpoint, prefix, size_per_shard)
|
self.save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
|
||||||
else:
|
else:
|
||||||
self.save_unsharded_optimizer(optimizer, checkpoint)
|
self.save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
|
||||||
|
|
||||||
# ========================================================
|
# ========================================================
|
||||||
# Abstract methods for model loading/saving implementation
|
# Abstract methods for model loading/saving implementation
|
||||||
# ========================================================
|
# ========================================================
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def load_sharded_model(self, model: nn.Module, checkpoint: Path, strict: bool):
|
def load_sharded_model(self, model: nn.Module, index_file_path: str, strict: bool):
|
||||||
"""
|
"""
|
||||||
Load model from sharded checkpoint.
|
Load model from sharded checkpoint.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model (nn.Module): model to be loaded.
|
model (nn.Module): model to be loaded.
|
||||||
checkpoint (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
|
index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
|
||||||
|
strict (bool): whether to strictly enforce that the param name in
|
||||||
|
the checkpoint match the keys returned by this module's.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def load_unsharded_model(self, model: nn.Module, checkpoint: Path, strict: bool):
|
def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
|
||||||
"""
|
"""
|
||||||
Load model from unsharded checkpoint.
|
Load model from unsharded checkpoint.
|
||||||
|
|
||||||
|
@ -184,26 +219,31 @@ class CheckpointIO(ABC):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def save_sharded_model(self, model: nn.Module, checkpoint: Path, prefix: str, size_per_shard: int):
|
def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: str,
|
||||||
|
size_per_shard: int, use_safetensors: bool):
|
||||||
"""
|
"""
|
||||||
Save model to sharded checkpoint.
|
Save model to sharded checkpoint.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model (nn.Module): model to be saved.
|
model (nn.Module): model to be saved.
|
||||||
checkpoint (Path): checkpoint path. It should be a directory path.
|
checkpoint (str): checkpoint path. It should be a directory path.
|
||||||
|
gather_dtensor (bool): whether to gather the distributed tensor to the first device.
|
||||||
prefix (str): prefix for the model checkpoint.
|
prefix (str): prefix for the model checkpoint.
|
||||||
size_per_shard (int): size per shard in MB.
|
size_per_shard (int): size per shard in MB.
|
||||||
|
use_safetensors (bool): whether to use safe tensors.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def save_unsharded_model(self, model: nn.Module, checkpoint: Path):
|
def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
|
||||||
"""
|
"""
|
||||||
Save model to unsharded checkpoint.
|
Save model to unsharded checkpoint.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model (nn.Module): model to be saved.
|
model (nn.Module): model to be saved.
|
||||||
checkpoint (Path): checkpoint path. It should be a single file path pointing to a model weight binary.
|
checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
|
||||||
|
gather_dtensor (bool): whether to gather the distributed tensor to the first device.
|
||||||
|
use_safetensors (bool): whether to use safe tensors.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -212,13 +252,13 @@ class CheckpointIO(ABC):
|
||||||
# ========================================================
|
# ========================================================
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
|
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int):
|
||||||
"""
|
"""
|
||||||
Load optimizer from sharded checkpoint.
|
Load optimizer from sharded checkpoint.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
optimizer (Optimizer): optimizer to be loaded.
|
optimizer (Optimizer): optimizer to be loaded.
|
||||||
checkpoint (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
|
index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
|
||||||
prefix (str): prefix for the optimizer checkpoint.
|
prefix (str): prefix for the optimizer checkpoint.
|
||||||
size_per_shard (int): size per shard in MB.
|
size_per_shard (int): size per shard in MB.
|
||||||
"""
|
"""
|
||||||
|
@ -236,26 +276,29 @@ class CheckpointIO(ABC):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
|
def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str,
|
||||||
|
size_per_shard: int):
|
||||||
"""
|
"""
|
||||||
Save optimizer to sharded checkpoint.
|
Save optimizer to sharded checkpoint.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
optimizer (Optimizer): optimizer to be saved.
|
optimizer (Optimizer): optimizer to be saved.
|
||||||
checkpoint (Path): checkpoint path. It should be a directory path.
|
checkpoint (Path): checkpoint path. It should be a directory path.
|
||||||
|
gather_dtensor (bool): whether to gather the distributed tensor to the first device.
|
||||||
prefix (str): prefix for the optimizer checkpoint.
|
prefix (str): prefix for the optimizer checkpoint.
|
||||||
size_per_shard (int): size per shard in MB.
|
size_per_shard (int): size per shard in MB.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
|
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool):
|
||||||
"""
|
"""
|
||||||
Save optimizer to unsharded checkpoint.
|
Save optimizer to unsharded checkpoint.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
optimizer (Optimizer): optimizer to be saved.
|
optimizer (Optimizer): optimizer to be saved.
|
||||||
checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
|
checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
|
||||||
|
gather_dtensor (bool): whether to gather the distributed tensor to the first device.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -264,7 +307,6 @@ class CheckpointIO(ABC):
|
||||||
# as this is quite standard, there is no need
|
# as this is quite standard, there is no need
|
||||||
# to make them abstract
|
# to make them abstract
|
||||||
# ============================================
|
# ============================================
|
||||||
|
|
||||||
def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
|
def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
|
||||||
"""
|
"""
|
||||||
Save lr scheduler to checkpoint.
|
Save lr scheduler to checkpoint.
|
||||||
|
@ -285,231 +327,3 @@ class CheckpointIO(ABC):
|
||||||
"""
|
"""
|
||||||
state_dict = torch.load(checkpoint)
|
state_dict = torch.load(checkpoint)
|
||||||
lr_scheduler.load_state_dict(state_dict)
|
lr_scheduler.load_state_dict(state_dict)
|
||||||
|
|
||||||
# ========================================
|
|
||||||
# Helper functions for loading state dict
|
|
||||||
# ========================================
|
|
||||||
|
|
||||||
def get_sharded_checkpoint_index_file(self, checkpoint_path: Path):
|
|
||||||
"""
|
|
||||||
Get the index file path for a sharded checkpoint.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
checkpoint_path (Path): path to the checkpoint.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Path: path to the index file.
|
|
||||||
"""
|
|
||||||
if checkpoint_path.is_file():
|
|
||||||
# check if it is .index.json
|
|
||||||
if checkpoint_path.name.endswith('.index.json'):
|
|
||||||
return checkpoint_path
|
|
||||||
else:
|
|
||||||
raise ValueError(f'Invalid checkpoint path: {checkpoint_path}. ')
|
|
||||||
elif checkpoint_path.is_dir():
|
|
||||||
# check if there is only one a file ending with .index.json in this directory
|
|
||||||
index_files = list(checkpoint_path.glob('*.index.json'))
|
|
||||||
if len(index_files) == 1:
|
|
||||||
return index_files[0]
|
|
||||||
else:
|
|
||||||
raise ValueError(f'Found {len(index_files)} index files in {checkpoint_path}. ')
|
|
||||||
|
|
||||||
def is_sharded_checkpoint(self, checkpoint_path: Path):
|
|
||||||
"""
|
|
||||||
Check whether the checkpoint is sharded.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
checkpoint (str): checkpoint path.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: whether the checkpoint is sharded.
|
|
||||||
"""
|
|
||||||
if checkpoint_path.is_file():
|
|
||||||
# check if it is .index.json
|
|
||||||
if checkpoint_path.name.endswith('.index.json'):
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
elif checkpoint_path.is_dir():
|
|
||||||
# check if there is only one a file ending with .index.json in this directory
|
|
||||||
index_files = list(checkpoint_path.glob('*.index.json'))
|
|
||||||
if len(index_files) == 1:
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
raise ValueError(f'Found {len(index_files)} index files in {checkpoint_path}. ')
|
|
||||||
|
|
||||||
def get_checkpoint_shard_filenames(self, index_file_path: Path):
|
|
||||||
"""
|
|
||||||
Get checkpoint shard filenames from a json file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
index_file_path (Path): path to the json file.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list: checkpoint shard filenames.
|
|
||||||
"""
|
|
||||||
with open(str(index_file_path), 'r') as f:
|
|
||||||
shard_filenames = json.load(f)
|
|
||||||
|
|
||||||
if "weight_map" in index:
|
|
||||||
index = index["weight_map"]
|
|
||||||
|
|
||||||
checkpoint_root_path = index_file_path.absolute().parent
|
|
||||||
|
|
||||||
# read the checkpoint file list from the json file and get a list of unique file names
|
|
||||||
checkpoint_files = sorted(list(set(index.values())))
|
|
||||||
|
|
||||||
# get the absolute paths for all checkpoint files
|
|
||||||
checkpoint_files = [checkpoint_root_path.joinpath(f) for f in checkpoint_files]
|
|
||||||
return shard_filenames
|
|
||||||
|
|
||||||
def load_safetensors_state_dict(self, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
Load safetensors state dict from checkpoint.
|
|
||||||
"""
|
|
||||||
# TODO(FrankLeeeee): support huggingface safetensors
|
|
||||||
raise NotImplementedError("This method is not implemented to support safe tensors")
|
|
||||||
|
|
||||||
def load_state_dict(self, checkpoint_file_path: Path):
|
|
||||||
"""
|
|
||||||
Load state dict from checkpoint.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
checkpoint_file_path (Path): path to the checkpoint file.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: state dict.
|
|
||||||
"""
|
|
||||||
return torch.load(str(checkpoint_file_path))
|
|
||||||
|
|
||||||
# ======================================
|
|
||||||
# Helper functions for saving state dict
|
|
||||||
# ======================================
|
|
||||||
|
|
||||||
def save_safetensors_state_dict(self, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
Save safetensors state dict to checkpoint.
|
|
||||||
"""
|
|
||||||
# TODO(FrankLeeeee): support huggingface safetensors
|
|
||||||
raise NotImplementedError("This method is not implemented to support safe tensors")
|
|
||||||
|
|
||||||
def generate_checkpoint_shard_file_name(self, index: int, total_number: int, prefix: str = None):
|
|
||||||
"""
|
|
||||||
Generate checkpoint shard file name.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
index (int): index of the shard.
|
|
||||||
total_number (int): total number of shards.
|
|
||||||
prefix (str): prefix of the shard file name. Default: None.
|
|
||||||
"""
|
|
||||||
if prefix is None:
|
|
||||||
return f"{index}-of-{total_number}.bin"
|
|
||||||
else:
|
|
||||||
return f"{prefix}-{index}-of-{total_number}.bin"
|
|
||||||
|
|
||||||
def save_checkpoint(self, state_dict: dict, checkpoint_file_path: Path):
|
|
||||||
"""
|
|
||||||
Save state dict to checkpoint.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state_dict (dict): state dict.
|
|
||||||
checkpoint_file_path (Path): path to the checkpoint file.
|
|
||||||
"""
|
|
||||||
torch.save(state_dict, str(checkpoint_file_path))
|
|
||||||
|
|
||||||
def save_state_dict_as_shard(self, state_dict: dict, index: int, total_number: int, prefix: str,
|
|
||||||
checkpoint_path: Path):
|
|
||||||
"""
|
|
||||||
Save state dict as shard.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state_dict (dict): state dict.
|
|
||||||
checkpoint_path (Path): path to the checkpoint file.
|
|
||||||
"""
|
|
||||||
# generate the shard name
|
|
||||||
shard_file_name = self.generate_checkpoint_shard_file_name(index, total_number, prefix)
|
|
||||||
shard_file_path = checkpoint_path.joinpath(shard_file_name)
|
|
||||||
|
|
||||||
# save the shard
|
|
||||||
self.save_checkpoint(state_dict, shard_file_path)
|
|
||||||
|
|
||||||
def calculate_param_size(self, param: torch.Tensor):
|
|
||||||
"""
|
|
||||||
Calculate the size of a parameter in MB. Used to compute whether a group of params exceed the shard size.
|
|
||||||
If so, a new shard should be created.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
param (torch.Tensor): parameter tensor.
|
|
||||||
"""
|
|
||||||
# TODO(FrankLeeeee): check if this tensor is a DTensor, compute its global size if so
|
|
||||||
return param.numel() * param.element_size() / 1024 / 1024
|
|
||||||
|
|
||||||
|
|
||||||
class ShardCheckpointIndexFile:
|
|
||||||
"""
|
|
||||||
This class is a data structure to keep the content in the index.json file for sharded checkpoint.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
>>> index = ShardCheckpointIndexFile()
|
|
||||||
>>> index.load('index.json')
|
|
||||||
>>> index.append_metadata('model_type', 'bert')
|
|
||||||
>>> index.append_weight_map('bert.embeddings.word_embeddings.weight', 'bert.embeddings.word_embeddings.weight-0-of-2.bin')
|
|
||||||
>>> index.export('index.json')
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self.metadata: dict = dict()
|
|
||||||
self.weight_map: dict = dict()
|
|
||||||
|
|
||||||
def load(self, json_path: str):
|
|
||||||
"""
|
|
||||||
Load the index file from a json file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
json_path (str): path to the json file.
|
|
||||||
"""
|
|
||||||
# load the json file
|
|
||||||
with open(json_path, 'r') as f:
|
|
||||||
index = json.load(f)
|
|
||||||
|
|
||||||
# assign attributes if exists
|
|
||||||
if "metadata" in index:
|
|
||||||
self.metadata = index["metadata"]
|
|
||||||
if "weight_map" in index:
|
|
||||||
self.weight_map = index["weight_map"]
|
|
||||||
|
|
||||||
def export(self, json_path: str):
|
|
||||||
"""
|
|
||||||
Export the index file to a json file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
json_path (str): path to the json file.
|
|
||||||
"""
|
|
||||||
# create the index file
|
|
||||||
index = dict()
|
|
||||||
index["metadata"] = self.metadata
|
|
||||||
index["weight_map"] = self.weight_map
|
|
||||||
|
|
||||||
# export the index file
|
|
||||||
with open(json_path, 'w') as f:
|
|
||||||
json.dump(index, f, indent=4)
|
|
||||||
|
|
||||||
def append_weight_map(self, param_name: str, shard_file: str):
|
|
||||||
"""
|
|
||||||
Append a weight map entry to the index file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
param_name (str): name of the parameter.
|
|
||||||
shard_file (str): name of the shard file.
|
|
||||||
"""
|
|
||||||
self.weight_map[param_name] = shard_file
|
|
||||||
|
|
||||||
def append_meta_data(self, name: str, val: Any):
|
|
||||||
"""
|
|
||||||
Append a metadata entry to the index file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name (str): name of the metadata.
|
|
||||||
val (Any): value of the metadata.
|
|
||||||
"""
|
|
||||||
self.metadata[name] = val
|
|
||||||
|
|
|
@ -4,42 +4,67 @@ import torch.nn as nn
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
|
|
||||||
from .checkpoint_io_base import CheckpointIO
|
from .checkpoint_io_base import CheckpointIO
|
||||||
|
from .index_file import CheckpointIndexFile
|
||||||
|
from .utils import has_index_file, load_state_dict, save_state_dict
|
||||||
|
|
||||||
__all__ = ['GeneralCheckpointIO']
|
__all__ = ['GeneralCheckpointIO']
|
||||||
|
|
||||||
|
|
||||||
class GeneralCheckpointIO(CheckpointIO):
|
class GeneralCheckpointIO(CheckpointIO):
|
||||||
|
|
||||||
def load_sharded_model(self, model: nn.Module, checkpoint: Path, strict: bool):
|
def load_sharded_model(self, model: nn.Module, index_file_path: Path, strict: bool):
|
||||||
index_file_path = self.get_sharded_checkpoint_index_file(checkpoint)
|
# load the index file
|
||||||
|
index_file = CheckpointIndexFile.from_file(index_file_path)
|
||||||
|
|
||||||
# iterate over the shard checkpoint files
|
# iterate over the shard checkpoint files
|
||||||
# and load each
|
# and load each
|
||||||
shard_files = self.get_checkpoint_shard_filenames(index_file_path)
|
index_file.assert_no_dtensor_checkpoint()
|
||||||
for shard_file in shard_files:
|
checkpoint_file_list, _ = index_file.get_checkpoint_fileanames()
|
||||||
shard_checkpoint = self.load_state_dict(shard_file)
|
for shard_file in checkpoint_file_list:
|
||||||
|
shard_checkpoint = load_state_dict(shard_file)
|
||||||
model.load_state_dict(shard_checkpoint, strict=strict)
|
model.load_state_dict(shard_checkpoint, strict=strict)
|
||||||
|
|
||||||
def load_unsharded_model(self, model: nn.Module, checkpoint: Path, strict: bool):
|
def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
|
||||||
checkpoint = self.load_state_dict(str(checkpoint))
|
checkpoint = load_state_dict(checkpoint)
|
||||||
model.load_state_dict(checkpoint, strict=strict)
|
model.load_state_dict(checkpoint, strict=strict)
|
||||||
|
|
||||||
def save_sharded_model(self, model: nn.Module, checkpoint: Path, prefix: str, size_per_shard: int):
|
def save_sharded_model(self, model: nn.Module, checkpoint: Path, gather_dtensor: bool, prefix: str,
|
||||||
|
size_per_shard: int, use_safetensors: bool):
|
||||||
# TODO(FrankLeeeee): implement this method as it can be supported by Huggingface model
|
# TODO(FrankLeeeee): implement this method as it can be supported by Huggingface model
|
||||||
raise NotImplementedError("Sharded model checkpoint is not supported yet.")
|
raise NotImplementedError("Sharded model checkpoint is not supported yet.")
|
||||||
|
|
||||||
def save_unsharded_model(self, model: nn.Module, checkpoint: Path):
|
def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
|
||||||
self.save_checkpoint(model.state_dict(), checkpoint)
|
state_dict = model.state_dict()
|
||||||
|
|
||||||
|
# TODO(FrankLeeeee): add support for gather_dtensor
|
||||||
|
if gather_dtensor:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# save the checkpoint
|
||||||
|
save_state_dict(state_dict, checkpoint, use_safetensors)
|
||||||
|
|
||||||
def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
|
def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
|
||||||
raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
|
raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
|
||||||
|
|
||||||
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
|
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
|
||||||
checkpoint = self.load_state_dict(checkpoint)
|
checkpoint = load_state_dict(checkpoint)
|
||||||
optimizer.load_state_dict(checkpoint)
|
optimizer.load_state_dict(checkpoint)
|
||||||
|
|
||||||
def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
|
def save_sharded_optimizer(
|
||||||
|
self,
|
||||||
|
optimizer: Optimizer,
|
||||||
|
checkpoint: Path,
|
||||||
|
gather_dtensor: bool,
|
||||||
|
prefix: str,
|
||||||
|
size_per_shard: int,
|
||||||
|
):
|
||||||
raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
|
raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
|
||||||
|
|
||||||
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
|
def save_unsharded_optimizer(
|
||||||
self.save_checkpoint(optimizer.state_dict(), checkpoint)
|
self,
|
||||||
|
optimizer: Optimizer,
|
||||||
|
checkpoint: Path,
|
||||||
|
gather_dtensor: bool,
|
||||||
|
):
|
||||||
|
# TODO(FrankLeeeee): handle distributed tensors
|
||||||
|
save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False)
|
||||||
|
|
|
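
At this point in the refactor, GeneralCheckpointIO still only implements the unsharded paths; the hedged snippet below (toy model, illustrative checkpoint directory) simply shows that requesting a sharded save surfaces the NotImplementedError raised above.

import torch.nn as nn

from colossalai.checkpoint_io import GeneralCheckpointIO

ckpt_io = GeneralCheckpointIO()
try:
    ckpt_io.save_model(nn.Linear(4, 4), './checkpoints/', shard=True)
except NotImplementedError as err:
    print(err)    # Sharded model checkpoint is not supported yet.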
@ -0,0 +1,150 @@
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, List, Union
|
||||||
|
|
||||||
|
from .utils import is_dtensor_checkpoint
|
||||||
|
|
||||||
|
__all__ = ['CheckpointIndexFile']
|
||||||
|
|
||||||
|
|
||||||
|
class CheckpointIndexFile:
|
||||||
|
"""
|
||||||
|
This class is a data structure to keep the content in the index.json file for sharded checkpoint.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> index = CheckpointIndexFile.from_file('model.index.json')
|
||||||
|
>>> index.append_meta_data('model_type', 'bert')
|
||||||
|
>>> index.append_weight_map('bert.embeddings.word_embeddings.weight', 'model_0001-of-0002.bin')
|
||||||
|
>>> index.export('new_index.json')
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.root_path = None
|
||||||
|
self.metadata: dict = dict()
|
||||||
|
self.weight_map: dict = dict()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_file(index_path: Union[str, Path]):
|
||||||
|
"""
|
||||||
|
Create a CheckpointIndexFile object from a json file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index_path (str): path to the json file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CheckpointIndexFile: CheckpointIndexFile object.
|
||||||
|
"""
|
||||||
|
index = CheckpointIndexFile()
|
||||||
|
index.load(index_path)
|
||||||
|
return index
|
||||||
|
|
||||||
|
def load(self, json_path: str):
|
||||||
|
"""
|
||||||
|
Load the index file from a json file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
json_path (str): path to the json file.
|
||||||
|
"""
|
||||||
|
# load the json file
|
||||||
|
with open(json_path, 'r') as f:
|
||||||
|
index = json.load(f)
|
||||||
|
|
||||||
|
# assign attributes if exists
|
||||||
|
if "metadata" in index:
|
||||||
|
self.metadata = index["metadata"]
|
||||||
|
if "weight_map" in index:
|
||||||
|
self.weight_map = index["weight_map"]
|
||||||
|
|
||||||
|
# assign the root directory for the index file
|
||||||
|
self.root_path = Path(json_path).absolute().parent
|
||||||
|
|
||||||
|
def export(self, json_path: str):
|
||||||
|
"""
|
||||||
|
Export the index file to a json file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
json_path (str): path to the json file.
|
||||||
|
"""
|
||||||
|
# create the index file
|
||||||
|
index = dict()
|
||||||
|
index["metadata"] = self.metadata
|
||||||
|
index["weight_map"] = self.weight_map
|
||||||
|
|
||||||
|
# export the index file
|
||||||
|
with open(json_path, 'w') as f:
|
||||||
|
json.dump(index, f, indent=4)
|
||||||
|
|
||||||
|
def append_weight_map(self, param_name: str, shard_file: str):
|
||||||
|
"""
|
||||||
|
Append a weight map entry to the index file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
param_name (str): name of the parameter.
|
||||||
|
shard_file (str): name of the shard file.
|
||||||
|
"""
|
||||||
|
self.weight_map[param_name] = shard_file
|
||||||
|
|
||||||
|
def append_meta_data(self, name: str, val: Any):
|
||||||
|
"""
|
||||||
|
Append a metadata entry to the index file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): name of the metadata.
|
||||||
|
val (Any): value of the metadata.
|
||||||
|
"""
|
||||||
|
self.metadata[name] = val
|
||||||
|
|
||||||
|
def contains_dtensor(self):
|
||||||
|
"""
|
||||||
|
Check if the index file contains any distributed tensor. The distributed tensors will be stored in
|
||||||
|
`dtensor/module.linear.weight.*.bin` or `dtensor/module.linear.weight.*.safetensors` in the weight map.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the index file contains any distributed tensor, False otherwise.
|
||||||
|
"""
|
||||||
|
for value in self.weight_map.values():
|
||||||
|
if value.endswith(".*.bin") or value.endswith(".*.safetensors"):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_checkpoint_fileanames(self) -> List[str]:
|
||||||
|
"""
|
||||||
|
Get the set of checkpoint filenames in the weight map.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: checkpoint shard filenames.
|
||||||
|
"""
|
||||||
|
# read the checkpoint file list from the json file and get a list of unique file names
|
||||||
|
checkpoint_files = sorted(list(set(self.weight_map.values())))
|
||||||
|
|
||||||
|
# get the absolute paths for all checkpoint files
|
||||||
|
checkpoint_files = [str(self.root_path.joinpath(f)) for f in checkpoint_files]
|
||||||
|
|
||||||
|
dtensor_list = []
|
||||||
|
checkpoint_list = []
|
||||||
|
|
||||||
|
for ckpt_file in checkpoint_files:
|
||||||
|
if is_dtensor_checkpoint(ckpt_file):
|
||||||
|
dtensor_list.append(ckpt_file)
|
||||||
|
else:
|
||||||
|
checkpoint_list.append(ckpt_file)
|
||||||
|
|
||||||
|
return checkpoint_list, dtensor_list
|
||||||
|
|
||||||
|
def assert_no_dtensor_checkpoint(self):
|
||||||
|
for val in self.weight_map.values():
|
||||||
|
if is_dtensor_checkpoint(val):
|
||||||
|
raise ValueError(f"Checkpoint file {val} contains distributed tensor")
|
||||||
|
|
||||||
|
def get_checkpoint_file(self, param_name: str) -> str:
|
||||||
|
"""
|
||||||
|
Get the checkpoint file name for a parameter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
param_name (str): name of the parameter.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: checkpoint file name.
|
||||||
|
"""
|
||||||
|
ckpt_path = self.weight_map[param_name]
|
||||||
|
return ckpt_path
|
|
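
For reference, a minimal sketch of the Hugging Face style index file that CheckpointIndexFile parses; the weight names, shard file names and total_size below are made up for illustration.

import json
import tempfile
from pathlib import Path

from colossalai.checkpoint_io import CheckpointIndexFile

index_path = Path(tempfile.mkdtemp()) / 'model.index.json'
index_path.write_text(json.dumps({
    "metadata": {"total_size": 1024},
    "weight_map": {
        "embeddings.weight": "model-00001-of-00002.bin",
        "encoder.layer.0.weight": "model-00002-of-00002.bin",
    },
}))

index = CheckpointIndexFile.from_file(index_path)
index.append_meta_data('model_type', 'bert')
# returns (regular checkpoint files, dtensor files); method name spelled as in this commit
shard_files, dtensor_files = index.get_checkpoint_fileanames()
index.export(str(index_path.with_name('new_index.json')))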
@ -0,0 +1,278 @@
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# ======================================
|
||||||
|
# General helper functions
|
||||||
|
# ======================================
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_tensor_size(tensor: torch.Tensor) -> float:
|
||||||
|
"""
|
||||||
|
Calculate the size of a parameter in MB. Used to compute whether a group of params exceed the shard size.
|
||||||
|
If so, a new shard should be created.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensor (torch.Tensor): the tensor to calculate size for.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float: size of the tensor in MB.
|
||||||
|
"""
|
||||||
|
return tensor.numel() * tensor.element_size() / 1024 / 1024
|
||||||
|
|
||||||
|
|
||||||
|
def is_safetensors_available() -> bool:
|
||||||
|
"""
|
||||||
|
Check whether safetensors is available.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: whether safetensors is available.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import safetensors
|
||||||
|
return True
|
||||||
|
except ImportError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def is_dtensor_checkpoint(checkpoint_file_path: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check whether the checkpoint file is a dtensor checkpoint.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
checkpoint_file_path (str): path to the checkpoint file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: whether the checkpoint file is a dtensor checkpoint.
|
||||||
|
"""
|
||||||
|
if checkpoint_file_path.endswith('.*.safetensors') or checkpoint_file_path.endswith('.*.bin'):
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check whether the checkpoint file is a safetensor checkpoint.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
checkpoint_file_path (str): path to the checkpoint file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: whether the checkpoint file is a safetensor checkpoint.
|
||||||
|
"""
|
||||||
|
if checkpoint_file_path.endswith('.safetensors'):
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
# ======================================
|
||||||
|
# Helper functions for saving state dict
|
||||||
|
# ======================================
|
||||||
|
|
||||||
|
|
||||||
|
def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) -> None:
|
||||||
|
"""
|
||||||
|
Save state dict to checkpoint.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state_dict (dict): state dict.
|
||||||
|
checkpoint_file_path (str): path to the checkpoint file.
|
||||||
|
use_safetensors (bool): whether to use safetensors to save the checkpoint.
|
||||||
|
"""
|
||||||
|
if use_safetensors:
|
||||||
|
assert is_safetensors_available(), "safetensors is not available."
|
||||||
|
assert checkpoint_file_path.endswith('.safetensors'), \
|
||||||
|
"safetensors only supports .safetensors suffix for checkpoint file."
|
||||||
|
from safetensors.torch import save_file
|
||||||
|
save_file(state_dict, checkpoint_file_path)
|
||||||
|
else:
|
||||||
|
torch.save(state_dict, checkpoint_file_path)
|
||||||
|
|
||||||
|
|
||||||
|
def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFile", use_safetensors: bool) -> None:
|
||||||
|
"""
|
||||||
|
Save distributed tensor to checkpoint. This checkpoint will be a dictionary which contains
|
||||||
|
only one tensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensor (Tensor): tensor to be saved.
|
||||||
|
index_file (CheckpointIndexFile): path to the checkpoint file.
|
||||||
|
size_per_shard (int): size per shard in MB.
|
||||||
|
"""
|
||||||
|
root_path = index_file.root_path
|
||||||
|
output_root_path = root_path.joinpath('dtensor')
|
||||||
|
|
||||||
|
# create directory
|
||||||
|
output_root_path.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
# save tensor to this directory
|
||||||
|
# TODO(YuliangLiu): get index of the tensor shard
|
||||||
|
# e.g. index =
|
||||||
|
index = 0
|
||||||
|
|
||||||
|
# save tensor to file
|
||||||
|
ckpt_file_name = generate_dtensor_file_name(name, index, use_safetensors)
|
||||||
|
ckpt_file_path = output_root_path.joinpath(ckpt_file_name)
|
||||||
|
|
||||||
|
# dtensor ckpt file always contains only one tensor
|
||||||
|
state_dict = {name: tensor}
|
||||||
|
save_state_dict(state_dict, str(ckpt_file_path), use_safetensors)
|
||||||
|
|
||||||
|
# update the weight map
|
||||||
|
# * means all shards
|
||||||
|
ckpt_file_name_in_weight_map = 'dtensor/' + generate_dtensor_file_name(name, '*', use_safetensors)
|
||||||
|
index_file.append_weight_map(name, ckpt_file_name_in_weight_map)
|
||||||
|
|
||||||
|
|
||||||
|
def get_checkpoint_file_suffix(use_safetensors: bool) -> str:
|
||||||
|
"""
|
||||||
|
Get checkpoint file suffix.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
use_safetensors (bool): whether to use safetensors to save the checkpoint.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: checkpoint file suffix.
|
||||||
|
"""
|
||||||
|
if use_safetensors:
|
||||||
|
return '.safetensors'
|
||||||
|
else:
|
||||||
|
return '.bin'
|
||||||
|
|
||||||
|
|
||||||
|
def generate_checkpoint_shard_file_name(index: int,
|
||||||
|
total_number: int,
|
||||||
|
use_safetensors: bool,
|
||||||
|
prefix: str = None) -> str:
|
||||||
|
"""
|
||||||
|
Generate checkpoint shard file name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index (int): index of the shard.
|
||||||
|
total_number (int): total number of shards.
|
||||||
|
use_safetensors (bool): whether to use safetensors to save the checkpoint.
|
||||||
|
prefix (str): prefix of the shard file name. Default: None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: checkpoint shard file name.
|
||||||
|
"""
|
||||||
|
suffix = get_checkpoint_file_suffix(use_safetensors)
|
||||||
|
|
||||||
|
if prefix is None:
|
||||||
|
return f"{index:05d}-of-{total_number:05d}.{suffix}"
|
||||||
|
else:
|
||||||
|
return f"{prefix}-{index:05d}-of-{total_number:05d}.{suffix}"
|
||||||
|
|
||||||
|
|
||||||
|
def generate_dtensor_file_name(param_name: str, index: int, use_safetensors: bool) -> str:
|
||||||
|
"""
|
||||||
|
Generate dtensor file name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
param_name (str): name of the distributed parameter.
|
||||||
|
index (int): index of the shard.
|
||||||
|
use_safetensors (bool): whether to use safetensors to save the checkpoint.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: dtensor file name.
|
||||||
|
"""
|
||||||
|
suffix = get_checkpoint_file_suffix(use_safetensors)
|
||||||
|
return f'{param_name}.{index}.{suffix}'
|
||||||
|
|
||||||
|
|
||||||
|
def save_state_dict_as_shard(
|
||||||
|
state_dict: dict,
|
||||||
|
checkpoint_path: str,
|
||||||
|
index: int,
|
||||||
|
total_number: int,
|
||||||
|
use_safetensors: bool,
|
||||||
|
prefix: str = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Save state dict as shard.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state_dict (dict): state dict.
|
||||||
|
checkpoint_path (str): path to the checkpoint file.
|
||||||
|
index (int): index of the shard.
|
||||||
|
total_number (int): total number of shards.
|
||||||
|
prefix (str): prefix of the shard file name.
|
||||||
|
use_safetensors (bool): whether to use safetensors to save the checkpoint.
|
||||||
|
"""
|
||||||
|
# generate the shard name
|
||||||
|
shard_file_name = generate_checkpoint_shard_file_name(index, total_number, use_safetensors, prefix)
|
||||||
|
shard_file_path = Path(checkpoint_path).joinpath(shard_file_name).absolute()
|
||||||
|
|
||||||
|
# save the shard
|
||||||
|
save_state_dict(state_dict, str(shard_file_path), use_safetensors)
|
||||||
|
|
||||||
|
|
||||||
|
# ========================================
|
||||||
|
# Helper functions for loading state dict
|
||||||
|
# ========================================
|
||||||
|
|
||||||
|
|
||||||
|
def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]:
|
||||||
|
"""
|
||||||
|
Check whether the checkpoint has an index file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
checkpoint_path (str): path to the checkpoint.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[bool, Optional[Path]]: a tuple of (has_index_file, index_file_path)
|
||||||
|
"""
|
||||||
|
checkpoint_path = Path(checkpoint_path)
|
||||||
|
if checkpoint_path.is_file():
|
||||||
|
# check if it is .index.json
|
||||||
|
if checkpoint_path.name.endswith('.index.json'):
|
||||||
|
return True, checkpoint_path
|
||||||
|
else:
|
||||||
|
return False, None
|
||||||
|
elif checkpoint_path.is_dir():
|
||||||
|
# check if there is only one file ending with .index.json in this directory
|
||||||
|
index_files = list(checkpoint_path.glob('*.index.json'))
|
||||||
|
|
||||||
|
# if we found a .index.json file, make sure there is only one
|
||||||
|
if len(index_files) > 0:
|
||||||
|
assert len(
|
||||||
|
index_files
|
||||||
|
) == 1, f'Expected to find one .index.json file in {checkpoint_path}, but found {len(index_files)}'
|
||||||
|
|
||||||
|
if len(index_files) == 1:
|
||||||
|
return True, index_files[0]
|
||||||
|
else:
|
||||||
|
return False, None
|
||||||
|
|
||||||
|
|
||||||
|
def load_state_dict(checkpoint_file_path: Path):
|
||||||
|
"""
|
||||||
|
Load state dict from checkpoint.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
checkpoint_file_path (Path): path to the checkpoint file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: state dict.
|
||||||
|
"""
|
||||||
|
|
||||||
|
assert not is_dtensor_checkpoint(checkpoint_file_path), \
|
||||||
|
f'Cannot load state dict from dtensor checkpoint {checkpoint_file_path}, you should convert the distributed tensors to gathered tensors with our CLI offline.'
|
||||||
|
|
||||||
|
if is_safetensor_checkpoint(checkpoint_file_path):
|
||||||
|
assert is_safetensors_available(), \
|
||||||
|
f'Cannot load state dict from safetensor checkpoint {checkpoint_file_path}, because safetensors is not available. Please install safetensors first with pip install safetensors.'
|
||||||
|
# load with safetensors
|
||||||
|
from safetensors import safe_open
|
||||||
|
state_dict = {}
|
||||||
|
with safe_open(checkpoint_file_path, framework="pt", device="cpu") as f:
|
||||||
|
for k in f.keys():
|
||||||
|
state_dict[k] = f.get_tensor(k)
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
else:
|
||||||
|
# load with torch
|
||||||
|
return torch.load(checkpoint_file_path)
|
|
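
A short round-trip sketch for the two state-dict helpers defined above, assuming safetensors is installed; the tensor names and the temporary file are illustrative.

import tempfile

import torch

from colossalai.checkpoint_io.utils import load_state_dict, save_state_dict

state_dict = {'weight': torch.randn(4, 4), 'bias': torch.zeros(4)}

ckpt_file = tempfile.NamedTemporaryFile(suffix='.safetensors')
save_state_dict(state_dict, ckpt_file.name, use_safetensors=True)

reloaded = load_state_dict(ckpt_file.name)
assert torch.equal(reloaded['weight'], state_dict['weight'])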
@@ -9,3 +9,4 @@ fabric
 contexttimer
 ninja
 torch>=1.11
+safetensors
@ -71,6 +71,29 @@ def check_dataloader_sharding():
|
||||||
batch_to_compare), 'Same number was found across ranks but expected it to be different'
|
batch_to_compare), 'Same number was found across ranks but expected it to be different'
|
||||||
|
|
||||||
|
|
||||||
|
def check_checkpoint_save_and_load():
|
||||||
|
model_fn, data_gen_fn, output_transform_fn, _ = model_zoo['timm_resnet']
|
||||||
|
|
||||||
|
plugin = TorchDDPPlugin()
|
||||||
|
booster = Booster(plugin=plugin)
|
||||||
|
|
||||||
|
model = model_fn()
|
||||||
|
optimizer = SGD(model.parameters(), lr=1e-3)
|
||||||
|
criterion = lambda x: x.mean()
|
||||||
|
data = data_gen_fn()
|
||||||
|
|
||||||
|
data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()}
|
||||||
|
|
||||||
|
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
|
||||||
|
|
||||||
|
output = model(**data)
|
||||||
|
output = output_transform_fn(output)
|
||||||
|
output_key = list(output.keys())[0]
|
||||||
|
loss = criterion(output[output_key])
|
||||||
|
|
||||||
|
booster.backward(loss, optimizer)
|
||||||
|
|
||||||
|
|
||||||
def run_dist(rank, world_size, port):
|
def run_dist(rank, world_size, port):
|
||||||
# init dist env
|
# init dist env
|
||||||
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
|
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import tempfile
|
import tempfile
|
||||||
|
|
||||||
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from torch.optim import Adam
|
from torch.optim import Adam
|
||||||
from torchvision.models import resnet18
|
from torchvision.models import resnet18
|
||||||
|
@ -14,7 +15,8 @@ from colossalai.checkpoint_io import GeneralCheckpointIO
|
||||||
# ========
|
# ========
|
||||||
|
|
||||||
|
|
||||||
def test_unsharded_checkpoint():
|
@pytest.mark.parametrize('use_safetensors', [True, False])
|
||||||
|
def test_unsharded_checkpoint(use_safetensors: bool):
|
||||||
# create a model and optimizer
|
# create a model and optimizer
|
||||||
model = resnet18()
|
model = resnet18()
|
||||||
optimizer = Adam(model.parameters(), lr=0.001)
|
optimizer = Adam(model.parameters(), lr=0.001)
|
||||||
|
@ -29,12 +31,16 @@ def test_unsharded_checkpoint():
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
# create a temp file for checkpoint
|
# create a temp file for checkpoint
|
||||||
model_ckpt_tempfile = tempfile.NamedTemporaryFile()
|
if use_safetensors:
|
||||||
|
suffix = ".safetensors"
|
||||||
|
else:
|
||||||
|
suffix = ".bin"
|
||||||
|
model_ckpt_tempfile = tempfile.NamedTemporaryFile(suffix=suffix)
|
||||||
optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile()
|
optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile()
|
||||||
|
|
||||||
# save the model and optimizer
|
# save the model and optimizer
|
||||||
ckpt_io = GeneralCheckpointIO()
|
ckpt_io = GeneralCheckpointIO()
|
||||||
ckpt_io.save_model(model, model_ckpt_tempfile.name)
|
ckpt_io.save_model(model, model_ckpt_tempfile.name, use_safetensors=use_safetensors)
|
||||||
ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name)
|
ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name)
|
||||||
|
|
||||||
# create new model
|
# create new model
|
||||||
|
@ -68,3 +74,4 @@ def test_unsharded_checkpoint():
|
||||||
# check for model and optimizer state dict recursively
|
# check for model and optimizer state dict recursively
|
||||||
recursive_check(model.state_dict(), new_model.state_dict())
|
recursive_check(model.state_dict(), new_model.state_dict())
|
||||||
recursive_check(optimizer.state_dict(), new_optimizer.state_dict())
|
recursive_check(optimizer.state_dict(), new_optimizer.state_dict())
|
||||||
|
recursive_check(optimizer.state_dict(), new_optimizer.state_dict())
|
||||||
|
|