[checkpointio] support load-pin overlap (#6177)

* [checkpointio] support load-pin overlap

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [test] add conftest

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Hongxin Liu 2025-01-07 16:16:04 +08:00 committed by GitHub
parent 479067e9bc
commit ee81366cac
6 changed files with 56 additions and 32 deletions
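What "load-pin overlap" means here: the per-shard branches (if shard_file.endswith(".safetensors") ...) duplicated across the checkpoint IO classes are folded into one generator, load_state_dict_shards, added to colossalai/checkpoint_io/utils.py below. When low_cpu_mem_mode is False, the generator prefetches shard files on a small thread pool, so reading the next shard from disk overlaps with pinning the shard just yielded. A minimal sketch of the pattern, using hypothetical stand-ins (read_shard, pin_state_dict) rather than the library's actual helpers:

import concurrent.futures
from typing import Dict, Generator, Iterable, List

import torch


def read_shard(path: str) -> Dict[str, torch.Tensor]:
    # stand-in for load_flat / load_shard_state_dict: read one shard file
    return torch.load(path)


def pin_state_dict(shard: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    # stand-in for create_pinned_state_dict: copy tensors into page-locked
    # memory (pin_memory() requires an accelerator-enabled torch build)
    return {k: v.pin_memory() for k, v in shard.items()}


def prefetch_shards(paths: Iterable[str], prefetch: int = 3) -> Generator[dict, None, None]:
    # At most `prefetch` files are read concurrently; while the consumer pins
    # the shard just yielded, the pool keeps reading ahead, overlapping disk
    # I/O with pinning.
    with concurrent.futures.ThreadPoolExecutor(max_workers=prefetch) as executor:
        futures = [executor.submit(read_shard, p) for p in paths]
        for future in concurrent.futures.as_completed(futures):
            yield future.result()


paths: List[str] = ["shard_0.bin", "shard_1.bin"]  # hypothetical shard files
for shard in prefetch_shards(paths):
    pinned = pin_state_dict(shard)  # runs on the main thread while the pool loads ahead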

colossalai/booster/plugin/gemini_plugin.py

@@ -20,7 +20,7 @@ from colossalai.checkpoint_io.utils import (
     create_pinned_state_dict,
     get_model_base_filenames,
     get_optimizer_base_filenames,
-    load_shard_state_dict,
+    load_state_dict_shards,
     save_config_file,
     save_state_dict,
     save_state_dict_shards,
@@ -29,7 +29,6 @@ from colossalai.cluster import DistCoordinator, ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.logging import get_dist_logger
 from colossalai.shardformer import ShardConfig, ShardFormer
-from colossalai.utils.safetensors import load_flat
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.memory_tracer import MemStats
@@ -350,11 +349,9 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         # Load optimizer states from shard files under checkpoint path.
         # For each file, only load the states managed by current process.
-        for shard_file in checkpoint_files:
-            if shard_file.endswith(".safetensors"):
-                state_dict_shard = load_flat(shard_file)
-            else:
-                state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False)
+        for state_dict_shard in load_state_dict_shards(
+            checkpoint_files, True, False, low_cpu_mem_mode=low_cpu_mem_mode
+        ):
             if not low_cpu_mem_mode:
                 state_dict_shard = create_pinned_state_dict(state_dict_shard, empty=False, num_threads=num_threads)
             optimizer.load_param_states(state_dict_shard)
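(In the new call, the positional True, False correspond to is_optim and use_safetensors in load_state_dict_shards; the helper itself is added in colossalai/checkpoint_io/utils.py below.)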

colossalai/booster/plugin/low_level_zero_plugin.py

@@ -24,8 +24,8 @@ from colossalai.checkpoint_io.utils import (
     get_optimizer_base_filenames,
     get_shard_filename,
     load_param_groups_into_optimizer,
-    load_shard_state_dict,
     load_state_dict,
+    load_state_dict_shards,
     load_states_into_optimizer,
     save_param_groups,
     save_state_dict,
@@ -276,13 +276,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
-        for shard_file in checkpoint_files:
-            if shard_file.endswith(".safetensors"):
-                from colossalai.utils.safetensors import load_flat
-
-                state_dict = load_flat(shard_file)
-            else:
-                state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
+        for state_dict in load_state_dict_shards(checkpoint_files, True, False, low_cpu_mem_mode):
             # shard state dict
             for param_idx, state in state_dict.items():
                 for k, v in state.items():

colossalai/booster/plugin/torch_fsdp_plugin.py

@@ -255,8 +255,8 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
         fsdp_state_dict = {}
-        for shard_file in checkpoint_files:
-            fsdp_state_dict.update(utils.load_shard_state_dict(Path(shard_file), use_safetensors))
+        for state_dict in utils.load_state_dict_shards(checkpoint_files, False, use_safetensors):
+            fsdp_state_dict.update(state_dict)
         with FSDP.state_dict_type(model.unwrap(), StateDictType.FULL_STATE_DICT):
             model.unwrap().load_state_dict(fsdp_state_dict, strict=False)
@@ -388,11 +388,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         # Load param
         fsdp_optim_state = {}
         checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
-        for shard_file in checkpoint_files:
-            if shard_file.endswith(".safetensors"):
-                state_dict_shard = load_flat(shard_file, seperator=".")
-            else:
-                state_dict_shard = utils.load_shard_state_dict(Path(shard_file), use_safetensors=False)
+        for state_dict_shard in utils.load_state_dict_shards(checkpoint_files, True, False):
             fsdp_optim_state.update(state_dict_shard)
         fsdp_optim_dict = dict(state=fsdp_optim_state, param_groups=saved_param_groups)

colossalai/checkpoint_io/general_checkpoint_io.py

@@ -18,9 +18,9 @@ from .utils import (
     get_optimizer_base_filenames,
     is_safetensors_available,
     load_param_groups_into_optimizer,
-    load_shard_state_dict,
     load_state_dict,
     load_state_dict_into_model,
+    load_state_dict_shards,
     load_states_into_optimizer,
     save_config_file,
     save_param_groups,
@@ -94,11 +94,7 @@ class GeneralCheckpointIO(CheckpointIO):
         checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
-        for shard_file in checkpoint_files:
-            if shard_file.endswith(".safetensors"):
-                state_dict = load_flat(shard_file)
-            else:
-                state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
+        for state_dict in load_state_dict_shards(checkpoint_files, True, False, low_cpu_mem_mode):
             if not low_cpu_mem_mode:
                 state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
             load_states_into_optimizer(optimizer, state_dict, id_map)
@@ -295,8 +291,7 @@ class GeneralCheckpointIO(CheckpointIO):
         checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
         missing_keys = []
-        for shard_file in checkpoint_files:
-            state_dict = load_shard_state_dict(Path(shard_file), use_safetensors)
+        for state_dict in load_state_dict_shards(checkpoint_files, False, use_safetensors, low_cpu_mem_mode):
             if not low_cpu_mem_mode:
                 state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
             load_state_dict_into_model(model, state_dict, missing_keys, strict, load_sub_module)

colossalai/checkpoint_io/utils.py

@@ -6,7 +6,7 @@ from collections import abc as container_abcs
 from collections import defaultdict
 from itertools import chain
 from pathlib import Path
-from typing import Dict, Iterator, List, Mapping, Optional, OrderedDict, Tuple, Union
+from typing import Dict, Generator, Iterator, List, Mapping, Optional, OrderedDict, Tuple, Union

 import torch
 import torch.nn as nn
@@ -21,7 +21,7 @@ from colossalai.tensor.d_tensor import (
     to_global,
     to_global_for_customized_distributed_tensor,
 )
-from colossalai.utils.safetensors import _flatten_optim_state_dict
+from colossalai.utils.safetensors import _flatten_optim_state_dict, load_flat

 SAFE_WEIGHTS_NAME = "model.safetensors"
 WEIGHTS_NAME = "pytorch_model.bin"
@@ -972,3 +972,35 @@ def create_pinned_state_dict(
             idx = future_to_idx[future]
             elems[idx] = future.result()
     return tree_unflatten(elems, spec)
+
+
+def load_optim_or_model_shard(path: str, is_optim: bool, use_safetensors: bool) -> dict:
+    if is_optim:
+        if path.endswith(".safetensors"):
+            state_dict = load_flat(path)
+        else:
+            state_dict = load_shard_state_dict(Path(path), use_safetensors=False)
+    else:
+        state_dict = load_shard_state_dict(Path(path), use_safetensors)
+    return state_dict
+
+
+def load_state_dict_shards(
+    checkpoint_files: List[str],
+    is_optim: bool,
+    use_safetensors: bool,
+    low_cpu_mem_mode: bool = True,
+    prefetch: int = 3,
+) -> Generator[dict, None, None]:
+    if low_cpu_mem_mode:
+        for shard_file in checkpoint_files:
+            state_dict = load_optim_or_model_shard(shard_file, is_optim, use_safetensors)
+            yield state_dict
+    else:
+        with concurrent.futures.ThreadPoolExecutor(max_workers=prefetch) as executor:
+            futures = []
+            for shard_file in checkpoint_files:
+                future = executor.submit(load_optim_or_model_shard, shard_file, is_optim, use_safetensors)
+                futures.append(future)
+            for future in concurrent.futures.as_completed(futures):
+                yield future.result()
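A short usage sketch combining the two helpers the way the call sites above do; the checkpoint paths and the model are assumptions for illustration:

from colossalai.checkpoint_io.utils import create_pinned_state_dict, load_state_dict_shards

checkpoint_files = ["model-00001.safetensors", "model-00002.safetensors"]  # hypothetical paths

# is_optim=False, use_safetensors=True; low_cpu_mem_mode=False enables the
# prefetch pool, so pinning the current shard overlaps with reading the next.
for state_dict in load_state_dict_shards(checkpoint_files, False, True, low_cpu_mem_mode=False):
    state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=8)
    model.load_state_dict(state_dict, strict=False)  # `model` assumed to exist

Note that with prefetching enabled, concurrent.futures.as_completed yields shards in completion order rather than file order; the call sites tolerate this because each shard is merged independently (dict.update or per-parameter state loads).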

tests/conftest.py (new file, 10 additions)

@@ -0,0 +1,10 @@
+import gc
+
+from colossalai.accelerator import get_accelerator
+
+
+def pytest_runtest_setup(item):
+    # called before each test: release cached accelerator memory and collect garbage
+    accelerator = get_accelerator()
+    accelerator.empty_cache()
+    gc.collect()