[checkpoint] support huggingface style sharded checkpoint (#3461)

* [checkpoint] support huggingface style sharded checkpoint

---------

Co-authored-by: luchen <luchen@luchendeMBP.lan>
jiangmingyan 2023-04-06 16:23:39 +08:00 committed by GitHub
parent 6afeb1202a
commit 52a933e175
4 changed files with 291 additions and 45 deletions


@@ -2,37 +2,35 @@ from pathlib import Path
 import torch.nn as nn
 from torch.optim import Optimizer
+import logging
+import os
+import json
+import gc
 from .checkpoint_io_base import CheckpointIO
 from .index_file import CheckpointIndexFile
-from .utils import has_index_file, load_state_dict, save_state_dict
+from .utils import (
+    has_index_file,
+    load_state_dict,
+    save_state_dict,
+    is_safetensors_available,
+    shard_checkpoint,
+    load_shard_state_dict,
+    load_state_dict_into_model
+)
+from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME

 __all__ = ['GeneralCheckpointIO']


 class GeneralCheckpointIO(CheckpointIO):
+    """
+    Checkpoint IO
+    """

-    def load_sharded_model(self, model: nn.Module, index_file_path: Path, strict: bool):
-        # load the index file
-        index_file = CheckpointIndexFile.from_file(index_file_path)
-
-        # iterate over the shard checkpoint files
-        # and load each
-        index_file.assert_no_dtensor_checkpoint()
-        checkpoint_file_list, _ = index_file.get_checkpoint_fileanames()
-        for shard_file in checkpoint_file_list:
-            shard_checkpoint = load_state_dict(shard_file)
-            model.load_state_dict(shard_checkpoint, strict=strict)
-
     def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
         checkpoint = load_state_dict(checkpoint)
         model.load_state_dict(checkpoint, strict=strict)

-    def save_sharded_model(self, model: nn.Module, checkpoint: Path, gather_dtensor: bool, prefix: str,
-                           size_per_shard: int, use_safetensors: bool):
-        # TODO(FrankLeeeee): implement this method as it can be supported by Huggingface model
-        raise NotImplementedError("Sharded model checkpoint is not supported yet.")
-
     def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
         state_dict = model.state_dict()
@@ -68,3 +66,68 @@ class GeneralCheckpointIO(CheckpointIO):
     ):
         # TODO(FrankLeeeee): handle distributed tensors
         save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False)
+
+    def save_sharded_model(self, model: nn.Module, checkpoint_path: str, gather_dtensor: bool = False,
+                           prefix: str = "", max_shard_size: int = 1024, use_safetensors: bool = False):
+        """
+        Save the model to multiple shard files, following the Huggingface sharded checkpoint layout:
+        one file per shard plus an index file that maps each parameter to its shard.
+        """
+        if os.path.isfile(checkpoint_path):
+            logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
+            return
+
+        Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
+
+        # shard the state dict
+        state_dict = model.state_dict()
+        weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME
+        shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name)
+
+        # save the shard files
+        for shard_file, shard in shards.items():
+            checkpoint_file_path = os.path.join(checkpoint_path, shard_file)
+            save_state_dict(shard, checkpoint_file_path, use_safetensors)
+
+        # save the index file
+        save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
+        save_index_file = os.path.join(checkpoint_path, save_index_file)
+        with open(save_index_file, "w", encoding="utf-8") as f:
+            content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+            f.write(content)
+
+        logging.info(
+            f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
+            f"split in {len(shards)} checkpoint shards. You can find where each parameter has been saved in the "
+            f"index located at {save_index_file}."
+        )
+
+    def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False,
+                           use_safetensors: bool = False):
+        """
+        Load a sharded model: read the index file, then load and apply each shard file in turn.
+        """
+        # infer the serialization format from the index file name
+        use_safetensors = False
+        if "safetensors" in checkpoint_index_file.name:
+            use_safetensors = True
+
+        if use_safetensors and not is_safetensors_available():
+            raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")
+
+        # read the checkpoint index file
+        ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
+        checkpoint_files, _ = ckpt_index_file.get_checkpoint_fileanames()
+        missing_keys = ckpt_index_file.get_all_param_names()
+
+        for shard_file in checkpoint_files:
+            state_dict = load_shard_state_dict(Path(shard_file), use_safetensors)
+            load_state_dict_into_model(model, state_dict, missing_keys, strict)
+            del state_dict
+            gc.collect()
+
+        if strict and len(missing_keys) > 0:
+            error_msgs = ['Missing key(s) in state_dict: {}. '.format(
+                ', '.join('"{}"'.format(k) for k in missing_keys))]
+            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
+                self.__class__.__name__, "\n\t".join(error_msgs)))
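
For orientation, here is a minimal usage sketch of the new sharded path through the GeneralCheckpointIO front-end. It mirrors the call pattern in the test added below; the checkpoint directory name is illustrative, and the positional arguments (shard=True, gather_dtensor=True, prefix="", size_per_shard=10) follow the test rather than a documented API:

    from torchvision.models import resnet18
    from colossalai.checkpoint_io import GeneralCheckpointIO

    model = resnet18()
    ckpt_io = GeneralCheckpointIO()

    # save: writes one file per shard plus model.bin.index.json into the directory
    # ("./resnet18_ckpt" is an illustrative path)
    ckpt_io.save_model(model, "./resnet18_ckpt", True, True, "", 10, use_safetensors=False)

    # load: reads the index file and loads each shard into the model
    new_model = resnet18()
    ckpt_io.load_model(new_model, "./resnet18_ckpt", strict=True)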


@@ -148,3 +148,9 @@ class CheckpointIndexFile:
         """
         ckpt_path = self.weight_map[param_name]
         return ckpt_path
+
+    def get_all_param_names(self):
+        """
+        Get all the weight keys.
+        """
+        return list(self.weight_map.keys())
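
As a reference for how the index file is consumed, a small sketch using the new helper together with the existing ones. The absolute import path and the index file location are inferred from the relative imports and file names above, so treat them as assumptions:

    from pathlib import Path
    from colossalai.checkpoint_io.index_file import CheckpointIndexFile

    # path is illustrative; it points at the index written by save_sharded_model
    index = CheckpointIndexFile.from_file(Path("./resnet18_ckpt/model.bin.index.json"))
    shard_files, _ = index.get_checkpoint_fileanames()    # shard files listed in the weight map
    expected_keys = index.get_all_param_names()           # every parameter key the shards should cover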


@@ -1,13 +1,19 @@
+# coding=utf-8
 from pathlib import Path
-from typing import List, Optional, Tuple

 import torch
+import torch.nn as nn
+from typing import List, Dict, Mapping, OrderedDict, Optional, Tuple
+from colossalai.tensor.d_tensor.d_tensor import DTensor
+
+SAFE_WEIGHTS_NAME = "model.safetensors"
+WEIGHTS_NAME = "model.bin"
+SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
+WEIGHTS_INDEX_NAME = "model.bin.index.json"

 # ======================================
 # General helper functions
 # ======================================


 def calculate_tensor_size(tensor: torch.Tensor) -> float:
     """
     Calculate the size of a parameter in MB. Used to compute whether a group of params exceed the shard size.
@@ -68,6 +74,130 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
     return False


+# ======================================
+# Helper functions for saving shard file
+# ======================================
+def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024, weights_name: str = WEIGHTS_NAME):
+    """
+    Splits a model state dictionary into sub-checkpoints so that the final size of each sub-checkpoint does not
+    exceed a given size.
+    """
+    sharded_state_dicts = []
+    current_block = {}
+    current_block_size = 0
+    total_size = 0
+
+    for key, weight in state_dict.items():
+        if type(weight) != DTensor:
+            weight_size = calculate_tensor_size(weight)
+
+            # If this weight would tip the current block over the maximal size, start a new block.
+            if current_block_size + weight_size > max_shard_size:
+                sharded_state_dicts.append(current_block)
+                current_block = {}
+                current_block_size = 0
+
+            current_block[key] = weight
+            current_block_size += weight_size
+            total_size += weight_size
+
+    # Add the last block
+    sharded_state_dicts.append(current_block)
+
+    # If we only have one shard, we return it
+    if len(sharded_state_dicts) == 1:
+        return {weights_name: sharded_state_dicts[0]}, None
+
+    # Otherwise, let's build the index
+    weight_map = {}
+    shards = {}
+    for idx, shard in enumerate(sharded_state_dicts):
+        shard_file = weights_name.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin")
+        shard_file = shard_file.replace(
+            ".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors"
+        )
+        shards[shard_file] = shard
+        for key in shard.keys():
+            weight_map[key] = shard_file
+
+    # Add the metadata
+    metadata = {"total_size": total_size}
+    index = {"metadata": metadata, "weight_map": weight_map}
+    return shards, index
+def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
+    """
+    Load a single shard file into a state dict.
+    """
+    if use_safetensors and not checkpoint_file.suffix == ".safetensors":
+        raise Exception("load the model using `safetensors`, but the checkpoint file does not end with .safetensors")
+    if use_safetensors:
+        from safetensors.torch import load_file as safe_load_file
+        from safetensors.torch import safe_open
+        with safe_open(checkpoint_file, framework="pt") as f:
+            metadata = f.metadata()
+        if metadata["format"] != "pt":
+            raise NotImplementedError(
+                f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet."
+            )
+        return safe_load_file(checkpoint_file)
+    else:
+        return torch.load(checkpoint_file)
+def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missing_keys: List, strict: bool = False):
+    r"""Copies parameters and buffers from :attr:`state_dict` into
+    this module and its descendants.
+
+    Args:
+        state_dict (dict): a dict containing parameters and
+            persistent buffers.
+    """
+    if not isinstance(state_dict, Mapping):
+        raise TypeError("Expected state_dict to be dict-like, got {}.".format(type(state_dict)))
+
+    unexpected_keys: List[str] = []
+    sub_missing_keys: List[str] = []
+    error_msgs: List[str] = []
+
+    # copy state_dict so _load_from_state_dict can modify it
+    metadata = getattr(state_dict, '_metadata', None)
+    state_dict = OrderedDict(state_dict)
+    if metadata is not None:
+        state_dict._metadata = metadata
+
+    def load(module: nn.Module, state_dict, prefix=""):
+        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
+        args = (state_dict, prefix, local_metadata, True, sub_missing_keys, unexpected_keys, error_msgs)
+        # Parameters of module and children will start with prefix. We can exit early if there are none in this
+        # state_dict
+        if len([key for key in state_dict if key.startswith(prefix)]) > 0:
+            module._load_from_state_dict(*args)
+
+        for name, child in module._modules.items():
+            if child is not None:
+                load(child, state_dict, prefix + name + ".")
+
+    load(model, state_dict, "")
+    del load
+
+    # drop the keys found in this shard from the caller's list of still-missing keys
+    deleted_keys = []
+    for key in missing_keys:
+        if key not in sub_missing_keys:
+            deleted_keys.append(key)
+    for key in deleted_keys:
+        missing_keys.remove(key)
+
+    if strict:
+        if len(unexpected_keys) > 0:
+            error_msgs = ['Unexpected key(s) in state_dict: {}. '.format(
+                ', '.join('"{}"'.format(k) for k in unexpected_keys))]
+            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
+                model.__class__.__name__, "\n\t".join(error_msgs)))
 # ======================================
 # Helper functions for saving state dict
 # ======================================

@@ -86,8 +216,8 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors
         assert is_safetensors_available(), "safetensors is not available."
         assert checkpoint_file_path.endswith('.safetensors'), \
             "safetensors only supports .safetensors suffix for checkpoint file."
-        from safetensors.torch import save_file
-        save_file(state_dict, checkpoint_file_path)
+        from safetensors.torch import save_file as safe_save_file
+        safe_save_file(state_dict, checkpoint_file_path, metadata={"format": "pt"})
     else:
         torch.save(state_dict, checkpoint_file_path)
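
To make the shard layout concrete, here is a sketch of what shard_checkpoint would return for a hypothetical two-shard ".bin" checkpoint. The parameter names and sizes are invented for illustration; note that total_size is accumulated from calculate_tensor_size, so in this code it is expressed in MB:

    # hypothetical return value of shard_checkpoint(state_dict, max_shard_size=1024)
    shards = {
        "model-00001-of-00002.bin": {"conv1.weight": ..., "layer1.0.conv1.weight": ...},
        "model-00002-of-00002.bin": {"fc.weight": ..., "fc.bias": ...},
    }
    index = {
        "metadata": {"total_size": 1536.0},    # MB, summed over all non-DTensor weights
        "weight_map": {
            "conv1.weight": "model-00001-of-00002.bin",
            "layer1.0.conv1.weight": "model-00001-of-00002.bin",
            "fc.weight": "model-00002-of-00002.bin",
            "fc.bias": "model-00002-of-00002.bin",
        },
    }
    # save_sharded_model() json-dumps `index` into the index file next to the shard files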


@@ -1,9 +1,12 @@
 import tempfile

 import pytest
 import torch
+import logging
 from torch.optim import Adam
 from torchvision.models import resnet18
+from pathlib import Path
+import os
+import subprocess

 from colossalai.checkpoint_io import GeneralCheckpointIO
 from colossalai.testing import clear_cache_before_run, parameterize

@@ -12,7 +15,7 @@ from colossalai.testing import clear_cache_before_run, parameterize
 # Note:
 # 1. due to checkpoint IO can be quite slow if tested with all models, we will only test on resnet for now
 # 2. we will test on both sharded and unsharded checkpoints
-# 3. TODO(FrankLeeeee): implement sharded checkpoint and test it
+# 3. implement sharded checkpoint and test it
 # ========
@@ -53,12 +56,61 @@ def test_unsharded_checkpoint(use_safetensors: bool):
     ckpt_io.load_model(new_model, model_ckpt_tempfile.name)
     ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name)

-    # do recursive check for the optimizer state dict
-    # if the value is a dict, compare its values
-    # if the value is a list, comapre all elements one-by-one
-    # if the value is a torch.Tensor, use torch.equal
-    # otherwise use assertEqual
-    def recursive_check(d1, d2):
+    # check for model and optimizer state dict recursively
+    recursive_check(model.state_dict(), new_model.state_dict())
+    recursive_check(optimizer.state_dict(), new_optimizer.state_dict())
+
+
+@pytest.mark.parametrize('use_safetensors', [True, False])
+def test_sharded_checkpoint(use_safetensors: bool):
+    # create a model and optimizer
+    model = resnet18()
+    optimizer = Adam(model.parameters(), lr=0.001)
+
+    # create test data sample
+    x = torch.randn(1, 3, 224, 224)
+
+    # run fwd and bwd
+    y = model(x)
+    loss = y.sum()
+    loss.backward()
+    optimizer.step()
+
+    # create a temp directory for the sharded checkpoint
+    if use_safetensors:
+        suffix = ".safetensors"
+        SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
+    else:
+        suffix = ".bin"
+        WEIGHTS_INDEX_NAME = "model.bin.index.json"
+    # model_ckpt_dir = tempfile.TemporaryDirectory(suffix=suffix)
+    model_ckpt_dir = tempfile.TemporaryDirectory()
+    optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile()
+
+    # save the model and optimizer
+    ckpt_io = GeneralCheckpointIO()
+    ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=use_safetensors)
+    ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name, shard=False)
+
+    # create a new model and optimizer, then load the sharded checkpoint back
+    new_model = resnet18()
+    new_optimizer = Adam(new_model.parameters(), lr=0.001)
+    ckpt_io.load_model(new_model, str(model_ckpt_dir.name), strict=True)
+    ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name)
+
+    # check for model and optimizer state dict recursively
+    recursive_check(model.state_dict(), new_model.state_dict())
+    recursive_check(optimizer.state_dict(), new_optimizer.state_dict())
+
+
+# do recursive check for the optimizer state dict
+# if the value is a dict, compare its values
+# if the value is a list, compare all elements one-by-one
+# if the value is a torch.Tensor, use torch.equal
+# otherwise use assertEqual
+def recursive_check(d1, d2):
     for k, v in d1.items():
         if isinstance(v, dict):
             recursive_check(v, d2[k])
@@ -72,8 +124,3 @@ def test_unsharded_checkpoint(use_safetensors: bool):
             assert torch.equal(v, d2[k])
         else:
             assert v == d2[k]
-
-    # check for model and optimizer state dict recursively
-    recursive_check(model.state_dict(), new_model.state_dict())
-    recursive_check(optimizer.state_dict(), new_optimizer.state_dict())