from pathlib import Path

import torch.nn as nn
from torch.optim import Optimizer

from .checkpoint_io_base import CheckpointIO
from .index_file import CheckpointIndexFile
from .utils import has_index_file, load_state_dict, save_state_dict

__all__ = ['GeneralCheckpointIO']


class GeneralCheckpointIO(CheckpointIO):
    """Checkpoint IO for plain (non-distributed) models and optimizers, i.e.
    state dicts that contain no distributed tensors."""

    def load_sharded_model(self, model: nn.Module, index_file_path: Path, strict: bool):
        # load the index file
        index_file = CheckpointIndexFile.from_file(index_file_path)
        index_file.assert_no_dtensor_checkpoint()

        # iterate over the shard checkpoint files and load each shard.
        # Each shard covers only a subset of the keys, so per-shard loading
        # must be non-strict; strictness is enforced over all shards afterwards.
        checkpoint_file_list, _ = index_file.get_checkpoint_fileanames()
        missing_keys = set(model.state_dict().keys())
        unexpected_keys = []
        for shard_file in checkpoint_file_list:
            shard_checkpoint = load_state_dict(shard_file)
            result = model.load_state_dict(shard_checkpoint, strict=False)
            missing_keys -= set(shard_checkpoint.keys())
            unexpected_keys.extend(result.unexpected_keys)

        if strict and (missing_keys or unexpected_keys):
            raise RuntimeError(
                f"Error(s) in loading state_dict: missing key(s) {sorted(missing_keys)}, "
                f"unexpected key(s) {sorted(unexpected_keys)}")

    def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
        state_dict = load_state_dict(checkpoint)
        model.load_state_dict(state_dict, strict=strict)

    def save_sharded_model(self, model: nn.Module, checkpoint: Path, gather_dtensor: bool, prefix: str,
                           size_per_shard: int, use_safetensors: bool):
        # TODO(FrankLeeeee): implement this method as it can be supported by Huggingface model
        raise NotImplementedError("Sharded model checkpoint is not supported yet.")

    def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
        state_dict = model.state_dict()

        # TODO(FrankLeeeee): add support for gather_dtensor
        if gather_dtensor:
            pass

        # save the checkpoint
        save_state_dict(state_dict, checkpoint, use_safetensors)

    def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
        raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")

    def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
        state_dict = load_state_dict(checkpoint)
        optimizer.load_state_dict(state_dict)

    def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str,
                               size_per_shard: int):
        raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")

    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool):
        # TODO(FrankLeeeee): handle distributed tensors
        save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False)
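

# Minimal usage sketch: round-trip an unsharded model/optimizer checkpoint
# through GeneralCheckpointIO. This assumes the relative `.utils` helpers
# (`load_state_dict`/`save_state_dict`) behave like plain
# `torch.load`/`torch.save` wrappers; the toy model and file names below are
# placeholders for illustration, not part of the module's public surface.
if __name__ == "__main__":
    import tempfile

    import torch

    model = nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    ckpt_io = GeneralCheckpointIO()

    with tempfile.TemporaryDirectory() as tmpdir:
        model_path = str(Path(tmpdir) / "model.bin")
        optim_path = str(Path(tmpdir) / "optimizer.bin")

        # save both state dicts, then reload them into the same objects
        ckpt_io.save_unsharded_model(model, model_path, gather_dtensor=False, use_safetensors=False)
        ckpt_io.save_unsharded_optimizer(optimizer, optim_path, gather_dtensor=False)
        ckpt_io.load_unsharded_model(model, model_path, strict=True)
        ckpt_io.load_unsharded_optimizer(optimizer, optim_path)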