from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, Optional, Union

import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler

from colossalai.interface import ModelWrapper
from colossalai.logging import get_dist_logger

from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, has_index_file

__all__ = ["CheckpointIO"]


class CheckpointIO(ABC):
    """
    CheckpointIO is the base class for all checkpoint IO classes. It defines the interface for checkpoint IO.

    Examples:
        >>> from colossalai.checkpoint_io import GeneralCheckpointIO
        >>> checkpoint_io = GeneralCheckpointIO()
        >>>
        >>> # load model from checkpoint
        >>> model = checkpoint_io.load_model(model, 'model.pt')
        >>>
        >>> # save model to checkpoint; any distributed tensor is gathered by default
        >>> checkpoint_io.save_model(model, 'model.pt')
        >>>
        >>> # if the model contains distributed tensors and you don't want to gather them,
        >>> # each rank will save its own shard of the distributed tensor
        >>> checkpoint_io.save_model(model, 'model.pt', gather_dtensor=False)
        >>>
        >>> # save model to a sharded checkpoint
        >>> checkpoint_io.save_model(model, './checkpoints/', shard=True)
        >>>
        >>> # save model to a sharded checkpoint without gathering distributed tensors
        >>> checkpoint_io.save_model(model, './checkpoints/', shard=True, gather_dtensor=False)
        >>>
        >>> # Note:
        >>> # 1. loading from distributed-tensor checkpoints is not supported; conversion from
        >>> #    distributed-tensor checkpoints to full-tensor checkpoints should be done offline via our CLI
        >>> # 2. you don't have to specify whether the checkpoint is sharded when loading the model,
        >>> #    as this is detected automatically
        >>>
        >>> # load model from a sharded checkpoint
        >>> model = checkpoint_io.load_model(model, './checkpoints/')
        >>>
        >>> # load model from an unsharded checkpoint
        >>> model = checkpoint_io.load_model(model, 'model.pt')
        >>>
        >>> # load optimizer from checkpoint
        >>> checkpoint_io.load_optimizer(optimizer, 'optimizer.pt')
        >>>
        >>> # save optimizer to checkpoint
        >>> checkpoint_io.save_optimizer(optimizer, 'optimizer.pt')
    """

    # ======================================
    # Public methods
    # ======================================
    def __init__(self):
        super().__init__()
        self.pinned_state_dicts: Dict[int, dict] = {}
        self.async_writers = []

    def _sync_io(self):
        # wait for all in-flight asynchronous checkpoint writers to finish, then drop them
        for writer in self.async_writers:
            writer.synchronize()
        self.async_writers.clear()

    def _sync_d2h(self):
        # wait for pending device-to-host copies issued by asynchronous checkpoint writers
        for writer in self.async_writers:
            writer.sync_before_step()

    def synchronize(self):
        """
        Wait for pending device-to-host copies issued by asynchronous checkpoint saving
        to complete. This method must be called before updating the model weights.
        """
        self._sync_d2h()

    def __del__(self):
        self._sync_d2h()
        self._sync_io()

    def load_model(
        self,
        model: Union[nn.Module, ModelWrapper],
        checkpoint: str,
        strict: bool = True,
        low_cpu_mem_mode: bool = True,
        num_threads: int = 1,
    ) -> Union[nn.Module, ModelWrapper]:
        """
        Load model from checkpoint.

        Args:
            model (nn.Module or ModelWrapper): model to be loaded.
            checkpoint (str): checkpoint path. This value is compatible with the model checkpoints in
                mainstream model zoos such as Hugging Face and TIMM. The checkpoint path can be:
                    1. a file path, e.g. 'model.pt'
                    2. a path to a json file which defines the index to the sharded checkpoint
                    3. a path to a folder containing a unique .index.json file for the sharded checkpoint
                Distributed tensors cannot be loaded directly; they must be gathered offline via our CLI first.
            strict (bool): whether to strictly enforce that the keys in the checkpoint
                match the keys returned by this module's state_dict() function.
            low_cpu_mem_mode (bool): whether to load the model in low CPU memory mode. If False, a RAM cache is used to accelerate loading. Default: True.
            num_threads (int): number of threads to use when loading the model. Only effective when low CPU memory mode is disabled. Default: 1.
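
        Example (a minimal sketch, assuming a concrete subclass such as GeneralCheckpointIO
        and an existing checkpoint at the given path):

            >>> checkpoint_io = GeneralCheckpointIO()
            >>> # single-file checkpoint
            >>> model = checkpoint_io.load_model(model, 'model.pt')
            >>> # sharded checkpoint: pass the index json file or the folder containing it
            >>> model = checkpoint_io.load_model(model, './checkpoints/')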
        """
        # Only sharded and unsharded checkpoints containing no distributed tensors can be loaded;
        # dtensor -> full tensor conversion should be done offline via our CLI.
        # The existence of an index file means the checkpoint is sharded.
        index_file_exists, index_file_path = has_index_file(checkpoint)

        # return the original model instead of the unwrapped model
        origin_model = model

        if index_file_exists:
            self.load_sharded_model(
                model, index_file_path, strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
            )
        else:
            path = Path(checkpoint, SAFE_WEIGHTS_NAME)
            if path.is_file():
                self.load_unsharded_model(
                    model, str(path), strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
                )
            else:
                path = Path(checkpoint, WEIGHTS_NAME)
                if path.is_file():
                    self.load_unsharded_model(
                        model, str(path), strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
                    )
                else:
                    self.load_unsharded_model(
                        model, checkpoint, strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
                    )

        return origin_model

    def save_model(
        self,
        model: Union[nn.Module, ModelWrapper],
        checkpoint: str,
        shard: bool = False,
        gather_dtensor: bool = True,
        prefix: str = None,
        size_per_shard: int = 1024,
        use_safetensors: bool = False,
        use_async: bool = False,
    ):
        """
        Save model to checkpoint.

        Examples:
            >>> from colossalai.checkpoint_io import GeneralCheckpointIO
            >>> checkpoint_io = GeneralCheckpointIO()
            >>>
            >>> # save model to a single file
            >>> checkpoint_io.save_model(model, 'model.pt')
            >>>
            >>> # save model to a sharded checkpoint
            >>> checkpoint_io.save_model(model, './checkpoints/', shard=True)

        Args:
            model (nn.Module or ModelWrapper): model to be saved.
            checkpoint (str): checkpoint path. The checkpoint path can be:
                1. a file path, e.g. 'model.pt'
                2. a directory path to save the sharded checkpoint, e.g. './checkpoints/' when shard = True.
            shard (bool): whether to shard the checkpoint. Default: False. If set to True, the checkpoint will be sharded into
                multiple files. The model shards will be specified by a `model.index.json` file. When shard = True, please ensure
                that the checkpoint path is a directory path instead of a file path.
            gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True.
            prefix (str): If specified, weights are saved in the format pytorch_model.<prefix>.bin. Default: None.
            size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard = True.
            use_safetensors (bool): whether to use safetensors. Default: False. If set to True, the checkpoint will be saved
                in safetensors format.
            use_async (bool): whether to save the model asynchronously. Requires safetensors; if use_safetensors is False,
                it will be enabled automatically with a warning. Default: False.
        """
        self._sync_io()
        if use_async and not use_safetensors:
            logger = get_dist_logger()
            logger.warning(
                "Async save is only supported when use_safetensors is set to True. "
                "Setting use_safetensors to True for async save."
            )
            use_safetensors = True

        if shard:
            self.save_sharded_model(
                model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors, use_async
            )
        else:
            self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async)

    def load_optimizer(
        self,
        optimizer: Optimizer,
        checkpoint: str,
        prefix: str = None,
        low_cpu_mem_mode: bool = True,
        num_threads: int = 1,
    ):
        """
        Load optimizer from checkpoint.

        Args:
            optimizer (Optimizer): optimizer to be loaded.
            checkpoint (str): checkpoint path. Like model checkpoints, this can be a single file path,
                a path to an index json file, or a folder containing a unique .index.json file for a
                sharded checkpoint.
            prefix (str, optional): A prefix added to parameter and buffer
                names to compose the keys in state_dict. Defaults to None.
            low_cpu_mem_mode (bool): whether to load the optimizer in low CPU memory mode. If False, a RAM cache is used to accelerate loading. Default: True.
            num_threads (int): number of threads to use when loading the optimizer. Only effective when low CPU memory mode is disabled. Default: 1.
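
        Example (a minimal sketch, assuming a concrete subclass such as GeneralCheckpointIO
        and an existing optimizer checkpoint):

            >>> checkpoint_io = GeneralCheckpointIO()
            >>> # single-file checkpoint
            >>> checkpoint_io.load_optimizer(optimizer, 'optimizer.pt')
            >>> # sharded checkpoint: pass the folder that contains the .index.json file
            >>> checkpoint_io.load_optimizer(optimizer, './optim_checkpoints/')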
        """
        index_file_exists, index_file_path = has_index_file(checkpoint)

        if Path(checkpoint).is_dir() and not index_file_exists:
            # if the checkpoint is a directory and there is no index file, raise an error
            raise ValueError(f"Cannot find index file in {checkpoint}")

        if index_file_exists:
            # the existence of an index file means it is a sharded checkpoint
            self.load_sharded_optimizer(
                optimizer, index_file_path, prefix, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
            )
        else:
            self.load_unsharded_optimizer(
                optimizer, checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
            )

    def save_optimizer(
        self,
        optimizer: Optimizer,
        checkpoint: str,
        shard: bool = False,
        gather_dtensor: bool = True,
        prefix: str = None,
        size_per_shard: int = 1024,
        use_async: bool = False,
    ):
        """
        Save optimizer to checkpoint. Saving optimizer states is not compatible with safetensors.

        Args:
            optimizer (Optimizer): optimizer to be saved.
            checkpoint (str): checkpoint path. The checkpoint path can be:
                1. a file path, e.g. 'optimizer.pt'
                2. a path to a json file which defines the index to the sharded checkpoint for the optimizer
                3. a path to a folder containing a unique .index.json file for the sharded checkpoint
            shard (bool): whether to shard the checkpoint. Default: False. If set to True, the checkpoint will be sharded into
                multiple files. The optimizer shards will be specified by an `optimizer.index.json` file.
            gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True.
            prefix (str): prefix for the optimizer checkpoint when shard = True. Default: None.
            size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True.
            use_async (bool): whether to save the optimizer states asynchronously. Default: False.
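
        Example (a minimal sketch, assuming a concrete subclass such as GeneralCheckpointIO):

            >>> checkpoint_io = GeneralCheckpointIO()
            >>> # save optimizer to a single file
            >>> checkpoint_io.save_optimizer(optimizer, 'optimizer.pt')
            >>> # save optimizer to a sharded checkpoint under a directory
            >>> checkpoint_io.save_optimizer(optimizer, './optim_checkpoints/', shard=True, size_per_shard=512)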
        """
        if shard:
            self.save_sharded_optimizer(
                optimizer, checkpoint, gather_dtensor, prefix, size_per_shard, use_async=use_async
            )
        else:
            self.save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor, use_async=use_async)

    # ========================================================
    # Abstract methods for model loading/saving implementation
    # ========================================================
    @abstractmethod
    def load_sharded_model(
        self, model: nn.Module, index_file_path: str, strict: bool, low_cpu_mem_mode: bool = True, num_threads: int = 1
    ):
        """
        Load model from sharded checkpoint.

        Args:
            model (nn.Module): model to be loaded.
            index_file_path (str): checkpoint path. It should be a path to the .index.json file or a path to a directory which contains a .index.json file.
            strict (bool): whether to strictly enforce that the keys in the checkpoint
                match the keys returned by this module's state_dict() function.
            low_cpu_mem_mode (bool): whether to load the model in low CPU memory mode. If False, a RAM cache is used to accelerate loading. Default: True.
            num_threads (int): number of threads to use when loading the model. Only effective when low CPU memory mode is disabled. Default: 1.
        """

    @abstractmethod
    def load_unsharded_model(
        self, model: nn.Module, checkpoint: str, strict: bool, low_cpu_mem_mode: bool = True, num_threads: int = 1
    ):
        """
        Load model from unsharded checkpoint.

        Args:
            model (nn.Module): model to be loaded.
            checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
            strict (bool): whether to strictly enforce that the keys in the checkpoint
                match the keys returned by this module's state_dict() function.
            low_cpu_mem_mode (bool): whether to load the model in low CPU memory mode. If False, a RAM cache is used to accelerate loading. Default: True.
            num_threads (int): number of threads to use when loading the model. Only effective when low CPU memory mode is disabled. Default: 1.
        """

    @abstractmethod
    def save_sharded_model(
        self,
        model: nn.Module,
        checkpoint: str,
        gather_dtensor: bool,
        prefix: Optional[str],
        size_per_shard: int,
        use_safetensors: bool,
        use_async: bool = False,
    ):
        """
        Save model to sharded checkpoint.

        Args:
            model (nn.Module): model to be saved.
            checkpoint (str): checkpoint path. It should be a directory path.
            gather_dtensor (bool): whether to gather the distributed tensor to the first device.
            prefix (str): prefix for the model checkpoint.
            size_per_shard (int): size per shard in MB.
            use_safetensors (bool): whether to use safetensors.
            use_async (bool): whether to save the model asynchronously. Default: False.
        """

    @abstractmethod
    def save_unsharded_model(
        self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
    ):
        """
        Save model to unsharded checkpoint.

        Args:
            model (nn.Module): model to be saved.
            checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
            gather_dtensor (bool): whether to gather the distributed tensor to the first device.
            use_safetensors (bool): whether to use safetensors.
            use_async (bool): whether to save the model asynchronously. Default: False.
        """

    # ========================================================
    # Abstract methods for optimizer loading/saving implementation
    # ========================================================

    @abstractmethod
    def load_sharded_optimizer(
        self,
        optimizer: Optimizer,
        index_file_path: str,
        prefix: str,
        low_cpu_mem_mode: bool = True,
        num_threads: int = 1,
    ):
        """
        Load optimizer from sharded checkpoint.

        Args:
            optimizer (Optimizer): optimizer to be loaded.
            index_file_path (str): checkpoint path. It should be a path to the .index.json file or a path to a directory which contains a .index.json file.
            prefix (str): prefix for the optimizer checkpoint.
            low_cpu_mem_mode (bool): whether to load the optimizer in low CPU memory mode. If False, a RAM cache is used to accelerate loading. Default: True.
            num_threads (int): number of threads to use when loading the optimizer. Only effective when low CPU memory mode is disabled. Default: 1.
        """

    @abstractmethod
    def load_unsharded_optimizer(
        self, optimizer: Optimizer, checkpoint: Path, low_cpu_mem_mode: bool = True, num_threads: int = 1
    ):
        """
        Load optimizer from unsharded checkpoint.

        Args:
            optimizer (Optimizer): optimizer to be loaded.
            checkpoint (Path): checkpoint path. It should be a single file path pointing to an optimizer state binary.
            low_cpu_mem_mode (bool): whether to load the optimizer in low CPU memory mode. If False, a RAM cache is used to accelerate loading. Default: True.
            num_threads (int): number of threads to use when loading the optimizer. Only effective when low CPU memory mode is disabled. Default: 1.
        """

    @abstractmethod
    def save_sharded_optimizer(
        self,
        optimizer: Optimizer,
        checkpoint: Path,
        gather_dtensor: bool,
        prefix: str,
        size_per_shard: int,
        use_async: bool = False,
    ):
        """
        Save optimizer to sharded checkpoint.

        Args:
            optimizer (Optimizer): optimizer to be saved.
            checkpoint (Path): checkpoint path. It should be a directory path.
            gather_dtensor (bool): whether to gather the distributed tensor to the first device.
            prefix (str): prefix for the optimizer checkpoint.
            size_per_shard (int): size per shard in MB.
            use_async (bool): whether to save the optimizer states asynchronously. Default: False.
        """

    @abstractmethod
    def save_unsharded_optimizer(
        self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, use_async: bool = False
    ):
        """
        Save optimizer to unsharded checkpoint.

        Args:
            optimizer (Optimizer): optimizer to be saved.
            checkpoint (Path): checkpoint path. It should be a single file path pointing to an optimizer state binary.
            gather_dtensor (bool): whether to gather the distributed tensor to the first device.
            use_async (bool): whether to save the optimizer states asynchronously. Default: False.
        """

    # ============================================
    # methods for loading and saving lr scheduler
    # as this is quite standard, there is no need
    # to make them abstract
    # ============================================
    def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
        """
        Save lr scheduler to checkpoint.

        Args:
            lr_scheduler (LRScheduler): lr scheduler to be saved.
            checkpoint (str): checkpoint path. The checkpoint path can only be a file path.
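
        Example (a minimal sketch; the scheduler checkpoint is a plain torch state_dict file):

            >>> checkpoint_io.save_lr_scheduler(lr_scheduler, 'lr_scheduler.pt')
            >>> # later, restore the state into a freshly constructed scheduler
            >>> checkpoint_io.load_lr_scheduler(lr_scheduler, 'lr_scheduler.pt')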
        """
        torch.save(lr_scheduler.state_dict(), checkpoint)

    def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
        """
        Load lr scheduler from checkpoint.

        Args:
            lr_scheduler (LRScheduler): lr scheduler to be loaded.
            checkpoint (str): the path for a single checkpoint file.
        """
        state_dict = torch.load(checkpoint)
        lr_scheduler.load_state_dict(state_dict)

    # ================================================================================
    # Abstract method for lora saving implementation.
    # ================================================================================

    @abstractmethod
    def save_lora_as_pretrained(
        self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
    ) -> None:
        """
        Save the lora adapters and adapter configuration file to a pretrained checkpoint directory.

        Args:
            model (Union[nn.Module, ModelWrapper]): A model boosted by Booster.
            checkpoint (str): Path to the checkpoint directory. It must be a local path.
            use_safetensors (bool, optional): Whether to use safetensors when saving. Defaults to False.
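
        Example (a minimal sketch, assuming the model was boosted with a LoRA-enabled plugin
        and a concrete CheckpointIO implementation is in use):

            >>> checkpoint_io.save_lora_as_pretrained(model, './lora_adapter/', use_safetensors=True)
            >>> # './lora_adapter/' then contains the adapter weights and the adapter configuration file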
        """