mirror of https://github.com/hpcaitech/ColossalAI
[checkpoint] support huggingface style sharded checkpoint (#3461)
* [checkpoint] support huggingface style sharded checkpoint

Co-authored-by: luchen <luchen@luchendeMBP.lan>
parent 6afeb1202a
commit 52a933e175
@@ -2,37 +2,35 @@ from pathlib import Path
 import torch.nn as nn
 from torch.optim import Optimizer
+import logging
+import os
+import json
+import gc

 from .checkpoint_io_base import CheckpointIO
 from .index_file import CheckpointIndexFile
-from .utils import has_index_file, load_state_dict, save_state_dict
+from .utils import (
+    has_index_file,
+    load_state_dict,
+    save_state_dict,
+    is_safetensors_available,
+    shard_checkpoint,
+    load_shard_state_dict,
+    load_state_dict_into_model
+)
+from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME

 __all__ = ['GeneralCheckpointIO']


 class GeneralCheckpointIO(CheckpointIO):
+    """
+    Checkpoint IO
+    """

-    def load_sharded_model(self, model: nn.Module, index_file_path: Path, strict: bool):
-        # load the index file
-        index_file = CheckpointIndexFile.from_file(index_file_path)
-
-        # iterate over the shard checkpoint files
-        # and load each
-        index_file.assert_no_dtensor_checkpoint()
-        checkpoint_file_list, _ = index_file.get_checkpoint_fileanames()
-        for shard_file in checkpoint_file_list:
-            shard_checkpoint = load_state_dict(shard_file)
-            model.load_state_dict(shard_checkpoint, strict=strict)
-
     def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
         checkpoint = load_state_dict(checkpoint)
         model.load_state_dict(checkpoint, strict=strict)

-    def save_sharded_model(self, model: nn.Module, checkpoint: Path, gather_dtensor: bool, prefix: str,
-                           size_per_shard: int, use_safetensors: bool):
-        # TODO(FrankLeeeee): implement this method as it can be supported by Huggingface model
-        raise NotImplementedError("Sharded model checkpoint is not supported yet.")
-
     def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
         state_dict = model.state_dict()
@@ -68,3 +66,68 @@ class GeneralCheckpointIO(CheckpointIO):
     ):
         # TODO(FrankLeeeee): handle distributed tensors
         save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False)

+    def save_sharded_model(self, model: nn.Module, checkpoint_path: str, gather_dtensor: bool = False,
+                           prefix: str = "", max_shard_size: int = 1024, use_safetensors: bool = False):
+        """
+        Save the model to multiple shard files (Huggingface-style sharded checkpoint),
+        together with an index file that maps each parameter to the shard it lives in.
+        """
+        if os.path.isfile(checkpoint_path):
+            logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
+            return
+
+        Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
+
+        # shard checkpoint
+        state_dict = model.state_dict()
+        weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME
+        shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name)
+
+        # save the model shards
+        for shard_file, shard in shards.items():
+            checkpoint_file_path = os.path.join(checkpoint_path, shard_file)
+            save_state_dict(shard, checkpoint_file_path, use_safetensors)
+
+        # save the index file
+        save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
+        save_index_file = os.path.join(checkpoint_path, save_index_file)
+        with open(save_index_file, "w", encoding="utf-8") as f:
+            content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+            f.write(content)
+        logging.info(
+            f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
+            f"split into {len(shards)} checkpoint shards. You can find where each parameter has been saved in the "
+            f"index located at {save_index_file}."
+        )
+
+    def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False):
+        """
+        Load a sharded model from the shard files listed in the given checkpoint index file.
+        """
+        use_safetensors = False
+        if "safetensors" in checkpoint_index_file.name:
+            use_safetensors = True
+
+        if use_safetensors and not is_safetensors_available():
+            raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")
+
+        # read the checkpoint index file
+        ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
+        checkpoint_files, _ = ckpt_index_file.get_checkpoint_fileanames()
+        missing_keys = ckpt_index_file.get_all_param_names()
+
+        for shard_file in checkpoint_files:
+            state_dict = load_shard_state_dict(Path(shard_file), use_safetensors)
+            load_state_dict_into_model(model, state_dict, missing_keys, strict)
+            del state_dict
+            gc.collect()
+
+        if strict and len(missing_keys) > 0:
+            error_msgs = ['Missing key(s) in state_dict: {}. '.format(
+                ', '.join('"{}"'.format(k) for k in missing_keys))]
+            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
+                self.__class__.__name__, "\n\t".join(error_msgs)))
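Taken together, the two methods above give GeneralCheckpointIO a Huggingface-style sharded save/load path. A minimal usage sketch (the directory name is illustrative; the test added later in this commit drives the same code through `CheckpointIO.save_model`/`load_model`):

    from pathlib import Path
    from torchvision.models import resnet18
    from colossalai.checkpoint_io import GeneralCheckpointIO

    model = resnet18()
    ckpt_io = GeneralCheckpointIO()

    # writes model-00001-of-000NN.bin shard files plus model.bin.index.json into the directory
    ckpt_io.save_sharded_model(model, "./resnet18_ckpt", max_shard_size=10, use_safetensors=False)

    # reload by pointing at the generated index file
    new_model = resnet18()
    ckpt_io.load_sharded_model(new_model, Path("./resnet18_ckpt/model.bin.index.json"), strict=True)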
@@ -148,3 +148,9 @@ class CheckpointIndexFile:
         """
         ckpt_path = self.weight_map[param_name]
         return ckpt_path
+
+    def get_all_param_names(self):
+        """
+        Get all the weight keys.
+        """
+        return list(self.weight_map.keys())
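The `weight_map` consumed here mirrors the Huggingface-style index file that `save_sharded_model` writes via `json.dumps(index, indent=2, sort_keys=True)`. An illustrative `model.bin.index.json` (parameter names, shard names and the size value are examples only; `total_size` is in MB in this implementation, see `calculate_tensor_size` in the utils changes below):

    {
      "metadata": {
        "total_size": 44.59
      },
      "weight_map": {
        "conv1.weight": "model-00001-of-00005.bin",
        "layer1.0.conv1.weight": "model-00001-of-00005.bin",
        "fc.bias": "model-00005-of-00005.bin"
      }
    }

`get_all_param_names()` simply returns the keys of this mapping, which `load_sharded_model` uses as its initial list of missing keys.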
@@ -1,13 +1,19 @@
+# coding=utf-8
 from pathlib import Path
-from typing import List, Optional, Tuple

 import torch
+import torch.nn as nn
+from typing import List, Dict, Mapping, OrderedDict, Optional, Tuple
+from colossalai.tensor.d_tensor.d_tensor import DTensor
+
+SAFE_WEIGHTS_NAME = "model.safetensors"
+WEIGHTS_NAME = "model.bin"
+SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
+WEIGHTS_INDEX_NAME = "model.bin.index.json"

 # ======================================
 # General helper functions
 # ======================================


 def calculate_tensor_size(tensor: torch.Tensor) -> float:
     """
     Calculate the size of a parameter in MB. Used to compute whether a group of params exceed the shard size.
@@ -68,6 +74,130 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
     return False


+# ======================================
+# Helper functions for saving shard file
+# ======================================
+def shard_checkpoint(state_dict: dict, max_shard_size: int = 1024, weights_name: str = WEIGHTS_NAME):
+    """
+    Splits a model state dictionary into sub-checkpoints so that the final size of each sub-checkpoint does not
+    exceed a given size.
+    """
+    sharded_state_dicts = []
+    current_block = {}
+    current_block_size = 0
+    total_size = 0
+
+    for key, weight in state_dict.items():
+        if type(weight) != DTensor:
+            weight_size = calculate_tensor_size(weight)
+
+            # If this weight is going to tip up over the maximal size, we split.
+            if current_block_size + weight_size > max_shard_size:
+                sharded_state_dicts.append(current_block)
+                current_block = {}
+                current_block_size = 0
+
+            current_block[key] = weight
+            current_block_size += weight_size
+            total_size += weight_size
+
+    # Add the last block
+    sharded_state_dicts.append(current_block)
+
+    # If we only have one shard, we return it
+    if len(sharded_state_dicts) == 1:
+        return {weights_name: sharded_state_dicts[0]}, None
+
+    # Otherwise, let's build the index
+    weight_map = {}
+    shards = {}
+    for idx, shard in enumerate(sharded_state_dicts):
+        shard_file = weights_name.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin")
+        shard_file = shard_file.replace(
+            ".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors"
+        )
+        shards[shard_file] = shard
+        for key in shard.keys():
+            weight_map[key] = shard_file
+
+    # Add the metadata
+    metadata = {"total_size": total_size}
+    index = {"metadata": metadata, "weight_map": weight_map}
+    return shards, index
+
+def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
+    """
+    Load a shard checkpoint file and return its state dict.
+    """
+    if use_safetensors and not checkpoint_file.suffix == ".safetensors":
+        raise Exception("load the model using `safetensors`, but the checkpoint file does not end with .safetensors")
+    if use_safetensors:
+        from safetensors.torch import safe_open
+        from safetensors.torch import load_file as safe_load_file
+        with safe_open(checkpoint_file, framework="pt") as f:
+            metadata = f.metadata()
+        if metadata["format"] != "pt":
+            raise NotImplementedError(
+                f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet."
+            )
+        return safe_load_file(checkpoint_file)
+    else:
+        return torch.load(checkpoint_file)
+
+def load_state_dict_into_model(model: nn.Module, state_dict: Mapping, missing_keys: List, strict: bool = False):
+    r"""Copies parameters and buffers from :attr:`state_dict` into
+    this module and its descendants.
+
+    Args:
+        state_dict (dict): a dict containing parameters and
+            persistent buffers.
+    """
+    if not isinstance(state_dict, Mapping):
+        raise TypeError("Expected state_dict to be dict-like, got {}.".format(type(state_dict)))
+
+    unexpected_keys: List[str] = []
+    sub_missing_keys: List[str] = []
+    error_msgs: List[str] = []
+
+    # copy state_dict so _load_from_state_dict can modify it
+    metadata = getattr(state_dict, '_metadata', None)
+    state_dict = OrderedDict(state_dict)
+    if metadata is not None:
+        state_dict._metadata = metadata
+
+    def load(module: nn.Module, state_dict, prefix=""):
+        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
+        args = (state_dict, prefix, local_metadata, True, sub_missing_keys, unexpected_keys, error_msgs)
+        # Parameters of module and children will start with prefix. We can exit early if there are none in this
+        # state_dict
+        if len([key for key in state_dict if key.startswith(prefix)]) > 0:
+            module._load_from_state_dict(*args)
+
+        for name, child in module._modules.items():
+            if child is not None:
+                load(child, state_dict, prefix + name + ".")
+
+    load(model, state_dict, "")
+    del load
+
+    # remove the keys that were matched by this shard from the running missing-key list
+    if len(missing_keys) > 0:
+        deleted_keys = []
+        for key in missing_keys:
+            if key not in sub_missing_keys:
+                deleted_keys.append(key)
+        for key in deleted_keys:
+            missing_keys.remove(key)
+
+    if strict:
+        if len(unexpected_keys) > 0:
+            error_msgs = ['Unexpected key(s) in state_dict: {}. '.format(
+                ', '.join('"{}"'.format(k) for k in unexpected_keys))]
+            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
+                model.__class__.__name__, "\n\t".join(error_msgs)))
+
 # ======================================
 # Helper functions for saving state dict
 # ======================================
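A small, self-contained sketch of `shard_checkpoint` in isolation; the import path is an assumption based on the package layout implied by this diff, and `max_shard_size` is in MB, matching `calculate_tensor_size`:

    import torch
    from colossalai.checkpoint_io.utils import shard_checkpoint, WEIGHTS_NAME  # assumed module path

    # a toy state dict of four tensors of roughly 1 MB each
    state_dict = {f"layer{i}.weight": torch.zeros(262144) for i in range(4)}

    # cap each shard at 2 MB so the state dict is split across several files
    shards, index = shard_checkpoint(state_dict, max_shard_size=2, weights_name=WEIGHTS_NAME)

    print(list(shards.keys()))    # e.g. ['model-00001-of-00002.bin', 'model-00002-of-00002.bin']
    print(index["weight_map"])    # maps every parameter name to the shard file that holds it

Because blocks are cut greedily in iteration order, individual shards can come out smaller than the cap; only the running block size triggers a split.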
@@ -86,8 +216,8 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors
         assert is_safetensors_available(), "safetensors is not available."
         assert checkpoint_file_path.endswith('.safetensors'), \
             "safetensors only supports .safetensors suffix for checkpoint file."
-        from safetensors.torch import save_file
-        save_file(state_dict, checkpoint_file_path)
+        from safetensors.torch import save_file as safe_save_file
+        safe_save_file(state_dict, checkpoint_file_path, metadata={"format": "pt"})
     else:
         torch.save(state_dict, checkpoint_file_path)
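The added `metadata={"format": "pt"}` tag is what `load_shard_state_dict` above checks before reading a safetensors shard back, so the two changes belong together. A minimal round trip under that assumption (file name and import path are illustrative):

    import torch
    from pathlib import Path
    from colossalai.checkpoint_io.utils import save_state_dict, load_shard_state_dict  # assumed module path

    state_dict = {"weight": torch.randn(4, 4)}
    save_state_dict(state_dict, "tiny.safetensors", use_safetensors=True)    # tags the file with format "pt"
    reloaded = load_shard_state_dict(Path("tiny.safetensors"), use_safetensors=True)
    assert torch.equal(state_dict["weight"], reloaded["weight"])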
@@ -1,9 +1,12 @@
 import tempfile

 import pytest
 import torch
+import logging
 from torch.optim import Adam
 from torchvision.models import resnet18
+from pathlib import Path
+import os
+import subprocess

 from colossalai.checkpoint_io import GeneralCheckpointIO
 from colossalai.testing import clear_cache_before_run, parameterize
@@ -12,7 +15,7 @@ from colossalai.testing import clear_cache_before_run, parameterize
 # Note:
 # 1. due to checkpoint IO can be quite slow if tested with all models, we will only test on resnet for now
 # 2. we will test on both sharded and unsharded checkpoints
-# 3. TODO(FrankLeeeee): implement sharded checkpoint and test it
+# 3. implement sharded checkpoint and test it
 # ========
@@ -53,12 +56,61 @@ def test_unsharded_checkpoint(use_safetensors: bool):
     ckpt_io.load_model(new_model, model_ckpt_tempfile.name)
     ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name)

-    # do recursive check for the optimizer state dict
-    # if the value is a dict, compare its values
-    # if the value is a list, comapre all elements one-by-one
-    # if the value is a torch.Tensor, use torch.equal
-    # otherwise use assertEqual
-    def recursive_check(d1, d2):
+    # check for model and optimizer state dict recursively
+    recursive_check(model.state_dict(), new_model.state_dict())
+    recursive_check(optimizer.state_dict(), new_optimizer.state_dict())
+
+
+@pytest.mark.parametrize('use_safetensors', [True, False])
+def test_sharded_checkpoint(use_safetensors: bool):
+    # create a model and optimizer
+    model = resnet18()
+    optimizer = Adam(model.parameters(), lr=0.001)
+
+    # create test data sample
+    x = torch.randn(1, 3, 224, 224)
+
+    # run fwd and bwd
+    y = model(x)
+    loss = y.sum()
+    loss.backward()
+    optimizer.step()
+
+    # create a temp directory/file for the checkpoints
+    if use_safetensors:
+        suffix = ".safetensors"
+        SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
+    else:
+        suffix = ".bin"
+        WEIGHTS_INDEX_NAME = "model.bin.index.json"
+
+    # model_ckpt_dir = tempfile.TemporaryDirectory(suffix=suffix)
+    model_ckpt_dir = tempfile.TemporaryDirectory()
+    optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile()
+
+    # save the model and optimizer
+    ckpt_io = GeneralCheckpointIO()
+    ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=use_safetensors)
+    ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name, shard=False)
+
+    # create a new model and optimizer and load the checkpoints back
+    new_model = resnet18()
+    new_optimizer = Adam(new_model.parameters(), lr=0.001)
+    ckpt_io.load_model(new_model, str(model_ckpt_dir.name), strict=True)
+    ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name)
+
+    # check for model and optimizer state dict recursively
+    recursive_check(model.state_dict(), new_model.state_dict())
+    recursive_check(optimizer.state_dict(), new_optimizer.state_dict())
+
+
+# do recursive check for the optimizer state dict
+# if the value is a dict, compare its values
+# if the value is a list, compare all elements one-by-one
+# if the value is a torch.Tensor, use torch.equal
+# otherwise use assertEqual
+def recursive_check(d1, d2):
     for k, v in d1.items():
         if isinstance(v, dict):
             recursive_check(v, d2[k])
@@ -72,8 +124,3 @@ def test_unsharded_checkpoint(use_safetensors: bool):
             assert torch.equal(v, d2[k])
         else:
             assert v == d2[k]
-
-    # check for model and optimizer state dict recursively
-    recursive_check(model.state_dict(), new_model.state_dict())
-    recursive_check(optimizer.state_dict(), new_optimizer.state_dict())
-    recursive_check(optimizer.state_dict(), new_optimizer.state_dict())
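With the new test in place, one way to run just the sharded case locally (the test file path is assumed; it is not shown in this mirror):

    pytest tests/test_checkpoint_io/test_general_checkpoint_io.py -k test_sharded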