[checkpointio] support load-pin overlap (#6177)

* [checkpointio] support load-pin overlap

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [test] add conftest

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Hongxin Liu 2025-01-07 16:16:04 +08:00 committed by GitHub
parent 479067e9bc
commit ee81366cac
6 changed files with 56 additions and 32 deletions
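
What "load-pin overlap" means here: the checkpoint loaders below now consume a shared load_state_dict_shards() generator instead of loading each shard inline, and when low_cpu_mem_mode is disabled that generator prefetches up to `prefetch` shards on a small thread pool, so reading the next shard from disk overlaps with pinning the shard that was just yielded. The sketch below only illustrates that producer/consumer pattern with stand-in helpers; it is not ColossalAI code.

# Illustrative sketch of the overlap pattern (stand-in helpers, not ColossalAI APIs).
import concurrent.futures
from typing import Dict, Generator, List

import torch


def _read_shard(path: str) -> Dict[str, torch.Tensor]:
    # Placeholder shard reader; the real code dispatches to safetensors/torch loaders.
    return torch.load(path, map_location="cpu")


def iter_shards_overlapped(paths: List[str], prefetch: int = 3) -> Generator[dict, None, None]:
    # Worker threads keep reading later shards while the consumer of this
    # generator pins (or otherwise processes) the shard it just received.
    with concurrent.futures.ThreadPoolExecutor(max_workers=prefetch) as executor:
        futures = [executor.submit(_read_shard, p) for p in paths]
        for future in concurrent.futures.as_completed(futures):
            yield future.result()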


@@ -20,7 +20,7 @@ from colossalai.checkpoint_io.utils import (
     create_pinned_state_dict,
     get_model_base_filenames,
     get_optimizer_base_filenames,
-    load_shard_state_dict,
+    load_state_dict_shards,
     save_config_file,
     save_state_dict,
     save_state_dict_shards,
@@ -29,7 +29,6 @@
 from colossalai.cluster import DistCoordinator, ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.logging import get_dist_logger
 from colossalai.shardformer import ShardConfig, ShardFormer
-from colossalai.utils.safetensors import load_flat
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.memory_tracer import MemStats
@@ -350,11 +349,9 @@ class GeminiCheckpointIO(GeneralCheckpointIO):

         # Load optimizer states from shard files under checkpoint path.
         # For each file, only load the states managed by current process.
-        for shard_file in checkpoint_files:
-            if shard_file.endswith(".safetensors"):
-                state_dict_shard = load_flat(shard_file)
-            else:
-                state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False)
+        for state_dict_shard in load_state_dict_shards(
+            checkpoint_files, True, False, low_cpu_mem_mode=low_cpu_mem_mode
+        ):
             if not low_cpu_mem_mode:
                 state_dict_shard = create_pinned_state_dict(state_dict_shard, empty=False, num_threads=num_threads)
             optimizer.load_param_states(state_dict_shard)


@@ -24,8 +24,8 @@ from colossalai.checkpoint_io.utils import (
     get_optimizer_base_filenames,
     get_shard_filename,
     load_param_groups_into_optimizer,
-    load_shard_state_dict,
     load_state_dict,
+    load_state_dict_shards,
     load_states_into_optimizer,
     save_param_groups,
     save_state_dict,
@@ -276,13 +276,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()

-        for shard_file in checkpoint_files:
-            if shard_file.endswith(".safetensors"):
-                from colossalai.utils.safetensors import load_flat
-
-                state_dict = load_flat(shard_file)
-            else:
-                state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
+        for state_dict in load_state_dict_shards(checkpoint_files, True, False, low_cpu_mem_mode):
             # shard state dict
             for param_idx, state in state_dict.items():
                 for k, v in state.items():


@@ -255,8 +255,8 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()

         fsdp_state_dict = {}
-        for shard_file in checkpoint_files:
-            fsdp_state_dict.update(utils.load_shard_state_dict(Path(shard_file), use_safetensors))
+        for state_dict in utils.load_state_dict_shards(checkpoint_files, False, use_safetensors):
+            fsdp_state_dict.update(state_dict)

         with FSDP.state_dict_type(model.unwrap(), StateDictType.FULL_STATE_DICT):
             model.unwrap().load_state_dict(fsdp_state_dict, strict=False)
@@ -388,11 +388,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         # Load param
         fsdp_optim_state = {}
         checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
-        for shard_file in checkpoint_files:
-            if shard_file.endswith(".safetensors"):
-                state_dict_shard = load_flat(shard_file, seperator=".")
-            else:
-                state_dict_shard = utils.load_shard_state_dict(Path(shard_file), use_safetensors=False)
+        for state_dict_shard in utils.load_state_dict_shards(checkpoint_files, True, False):
             fsdp_optim_state.update(state_dict_shard)

         fsdp_optim_dict = dict(state=fsdp_optim_state, param_groups=saved_param_groups)


@@ -18,9 +18,9 @@ from .utils import (
     get_optimizer_base_filenames,
     is_safetensors_available,
     load_param_groups_into_optimizer,
-    load_shard_state_dict,
     load_state_dict,
     load_state_dict_into_model,
+    load_state_dict_shards,
     load_states_into_optimizer,
     save_config_file,
     save_param_groups,
@@ -94,11 +94,7 @@ class GeneralCheckpointIO(CheckpointIO):
         checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()

-        for shard_file in checkpoint_files:
-            if shard_file.endswith(".safetensors"):
-                state_dict = load_flat(shard_file)
-            else:
-                state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
+        for state_dict in load_state_dict_shards(checkpoint_files, True, False, low_cpu_mem_mode):
             if not low_cpu_mem_mode:
                 state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
             load_states_into_optimizer(optimizer, state_dict, id_map)
@@ -295,8 +291,7 @@ class GeneralCheckpointIO(CheckpointIO):
         checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
         missing_keys = []

-        for shard_file in checkpoint_files:
-            state_dict = load_shard_state_dict(Path(shard_file), use_safetensors)
+        for state_dict in load_state_dict_shards(checkpoint_files, False, use_safetensors, low_cpu_mem_mode):
             if not low_cpu_mem_mode:
                 state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
             load_state_dict_into_model(model, state_dict, missing_keys, strict, load_sub_module)


@@ -6,7 +6,7 @@ from collections import abc as container_abcs
 from collections import defaultdict
 from itertools import chain
 from pathlib import Path
-from typing import Dict, Iterator, List, Mapping, Optional, OrderedDict, Tuple, Union
+from typing import Dict, Generator, Iterator, List, Mapping, Optional, OrderedDict, Tuple, Union

 import torch
 import torch.nn as nn
@@ -21,7 +21,7 @@ from colossalai.tensor.d_tensor import (
     to_global,
     to_global_for_customized_distributed_tensor,
 )
-from colossalai.utils.safetensors import _flatten_optim_state_dict
+from colossalai.utils.safetensors import _flatten_optim_state_dict, load_flat

 SAFE_WEIGHTS_NAME = "model.safetensors"
 WEIGHTS_NAME = "pytorch_model.bin"
@@ -972,3 +972,35 @@ def create_pinned_state_dict(
             idx = future_to_idx[future]
             elems[idx] = future.result()
     return tree_unflatten(elems, spec)
+
+
+def load_optim_or_model_shard(path: str, is_optim: bool, use_safetensors: bool) -> dict:
+    if is_optim:
+        if path.endswith(".safetensors"):
+            state_dict = load_flat(path)
+        else:
+            state_dict = load_shard_state_dict(Path(path), use_safetensors=False)
+    else:
+        state_dict = load_shard_state_dict(Path(path), use_safetensors)
+    return state_dict
+
+
+def load_state_dict_shards(
+    checkpoint_files: List[str],
+    is_optim: bool,
+    use_safetensors: bool,
+    low_cpu_mem_mode: bool = True,
+    prefetch: int = 3,
+) -> Generator[dict, None, None]:
+    if low_cpu_mem_mode:
+        for shard_file in checkpoint_files:
+            state_dict = load_optim_or_model_shard(shard_file, is_optim, use_safetensors)
+            yield state_dict
+    else:
+        with concurrent.futures.ThreadPoolExecutor(max_workers=prefetch) as executor:
+            futures = []
+            for shard_file in checkpoint_files:
+                future = executor.submit(load_optim_or_model_shard, shard_file, is_optim, use_safetensors)
+                futures.append(future)
+            for future in concurrent.futures.as_completed(futures):
+                yield future.result()
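
A hedged usage sketch of the new generator, mirroring the optimizer-loading call sites above; the wrapper function, its consume callback, and the thread counts are illustrative, not part of the commit:

from typing import Callable, Dict, List

from colossalai.checkpoint_io.utils import create_pinned_state_dict, load_state_dict_shards


def load_optimizer_shards(
    checkpoint_files: List[str],
    consume: Callable[[Dict], None],
    low_cpu_mem_mode: bool = False,
    num_threads: int = 4,
) -> None:
    # is_optim=True, use_safetensors=False: the same positional flags the plugins pass.
    for shard in load_state_dict_shards(checkpoint_files, True, False, low_cpu_mem_mode=low_cpu_mem_mode):
        if not low_cpu_mem_mode:
            # Pinning runs on this thread while the generator's worker threads
            # (max_workers=prefetch) keep reading the remaining shard files.
            shard = create_pinned_state_dict(shard, empty=False, num_threads=num_threads)
        consume(shard)

With low_cpu_mem_mode=True the generator falls back to loading one shard at a time and the callers skip the pinned copy, which keeps peak host memory at roughly a single shard.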

tests/conftest.py (new file, 10 additions)

@@ -0,0 +1,10 @@
+import gc
+
+from colossalai.accelerator import get_accelerator
+
+
+def pytest_runtest_setup(item):
+    # called for running each test in 'a' directory
+    accelerator = get_accelerator()
+    accelerator.empty_cache()
+    gc.collect()