From ee81366cac775dfb072bccec1cfb6b424722e2c2 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Tue, 7 Jan 2025 16:16:04 +0800 Subject: [PATCH] [checkpointio] support load-pin overlap (#6177) * [checkpointio] support load-pin overlap * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [test] add conftest --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- colossalai/booster/plugin/gemini_plugin.py | 11 +++--- .../booster/plugin/low_level_zero_plugin.py | 10 ++---- .../booster/plugin/torch_fsdp_plugin.py | 10 ++---- .../checkpoint_io/general_checkpoint_io.py | 11 ++---- colossalai/checkpoint_io/utils.py | 36 +++++++++++++++++-- tests/conftest.py | 10 ++++++ 6 files changed, 56 insertions(+), 32 deletions(-) create mode 100644 tests/conftest.py diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index ba43a5066..4b1224c68 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -20,7 +20,7 @@ from colossalai.checkpoint_io.utils import ( create_pinned_state_dict, get_model_base_filenames, get_optimizer_base_filenames, - load_shard_state_dict, + load_state_dict_shards, save_config_file, save_state_dict, save_state_dict_shards, @@ -29,7 +29,6 @@ from colossalai.cluster import DistCoordinator, ProcessGroupMesh from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.logging import get_dist_logger from colossalai.shardformer import ShardConfig, ShardFormer -from colossalai.utils.safetensors import load_flat from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.memory_tracer import MemStats @@ -350,11 +349,9 @@ class GeminiCheckpointIO(GeneralCheckpointIO): # Load optimizer states from shard files under checkpoint path. # For each file, only load the states managed by current process. 
- for shard_file in checkpoint_files: - if shard_file.endswith(".safetensors"): - state_dict_shard = load_flat(shard_file) - else: - state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False) + for state_dict_shard in load_state_dict_shards( + checkpoint_files, True, False, low_cpu_mem_mode=low_cpu_mem_mode + ): if not low_cpu_mem_mode: state_dict_shard = create_pinned_state_dict(state_dict_shard, empty=False, num_threads=num_threads) optimizer.load_param_states(state_dict_shard) diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index 0bb4ae9ed..d29098a6e 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -24,8 +24,8 @@ from colossalai.checkpoint_io.utils import ( get_optimizer_base_filenames, get_shard_filename, load_param_groups_into_optimizer, - load_shard_state_dict, load_state_dict, + load_state_dict_shards, load_states_into_optimizer, save_param_groups, save_state_dict, @@ -276,13 +276,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames() - for shard_file in checkpoint_files: - if shard_file.endswith(".safetensors"): - from colossalai.utils.safetensors import load_flat - - state_dict = load_flat(shard_file) - else: - state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False) + for state_dict in load_state_dict_shards(checkpoint_files, True, False, low_cpu_mem_mode): # shard state dict for param_idx, state in state_dict.items(): for k, v in state.items(): diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py index 182fb3c7b..dca7d43c0 100644 --- a/colossalai/booster/plugin/torch_fsdp_plugin.py +++ b/colossalai/booster/plugin/torch_fsdp_plugin.py @@ -255,8 +255,8 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO): checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames() fsdp_state_dict = {} - for shard_file in checkpoint_files: - fsdp_state_dict.update(utils.load_shard_state_dict(Path(shard_file), use_safetensors)) + for state_dict in utils.load_state_dict_shards(checkpoint_files, False, use_safetensors): + fsdp_state_dict.update(state_dict) with FSDP.state_dict_type(model.unwrap(), StateDictType.FULL_STATE_DICT): model.unwrap().load_state_dict(fsdp_state_dict, strict=False) @@ -388,11 +388,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO): # Load param fsdp_optim_state = {} checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames() - for shard_file in checkpoint_files: - if shard_file.endswith(".safetensors"): - state_dict_shard = load_flat(shard_file, seperator=".") - else: - state_dict_shard = utils.load_shard_state_dict(Path(shard_file), use_safetensors=False) + for state_dict_shard in utils.load_state_dict_shards(checkpoint_files, True, False): fsdp_optim_state.update(state_dict_shard) fsdp_optim_dict = dict(state=fsdp_optim_state, param_groups=saved_param_groups) diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index 78404f908..c38958ee3 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -18,9 +18,9 @@ from .utils import ( get_optimizer_base_filenames, is_safetensors_available, load_param_groups_into_optimizer, - load_shard_state_dict, load_state_dict, load_state_dict_into_model, + load_state_dict_shards, 
load_states_into_optimizer, save_config_file, save_param_groups, @@ -94,11 +94,7 @@ class GeneralCheckpointIO(CheckpointIO): checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames() - for shard_file in checkpoint_files: - if shard_file.endswith(".safetensors"): - state_dict = load_flat(shard_file) - else: - state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False) + for state_dict in load_state_dict_shards(checkpoint_files, True, False, low_cpu_mem_mode): if not low_cpu_mem_mode: state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads) load_states_into_optimizer(optimizer, state_dict, id_map) @@ -295,8 +291,7 @@ class GeneralCheckpointIO(CheckpointIO): checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames() missing_keys = [] - for shard_file in checkpoint_files: - state_dict = load_shard_state_dict(Path(shard_file), use_safetensors) + for state_dict in load_state_dict_shards(checkpoint_files, False, use_safetensors, low_cpu_mem_mode): if not low_cpu_mem_mode: state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads) load_state_dict_into_model(model, state_dict, missing_keys, strict, load_sub_module) diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 7b322b657..50b6f1438 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -6,7 +6,7 @@ from collections import abc as container_abcs from collections import defaultdict from itertools import chain from pathlib import Path -from typing import Dict, Iterator, List, Mapping, Optional, OrderedDict, Tuple, Union +from typing import Dict, Generator, Iterator, List, Mapping, Optional, OrderedDict, Tuple, Union import torch import torch.nn as nn @@ -21,7 +21,7 @@ from colossalai.tensor.d_tensor import ( to_global, to_global_for_customized_distributed_tensor, ) -from colossalai.utils.safetensors import _flatten_optim_state_dict +from colossalai.utils.safetensors import _flatten_optim_state_dict, load_flat SAFE_WEIGHTS_NAME = "model.safetensors" WEIGHTS_NAME = "pytorch_model.bin" @@ -972,3 +972,35 @@ def create_pinned_state_dict( idx = future_to_idx[future] elems[idx] = future.result() return tree_unflatten(elems, spec) + + +def load_optim_or_model_shard(path: str, is_optim: bool, use_safetensors: bool) -> dict: + if is_optim: + if path.endswith(".safetensors"): + state_dict = load_flat(path) + else: + state_dict = load_shard_state_dict(Path(path), use_safetensors=False) + else: + state_dict = load_shard_state_dict(Path(path), use_safetensors) + return state_dict + + +def load_state_dict_shards( + checkpoint_files: List[str], + is_optim: bool, + use_safetensors: bool, + low_cpu_mem_mode: bool = True, + prefetch: int = 3, +) -> Generator[dict, None, None]: + if low_cpu_mem_mode: + for shard_file in checkpoint_files: + state_dict = load_optim_or_model_shard(shard_file, is_optim, use_safetensors) + yield state_dict + else: + with concurrent.futures.ThreadPoolExecutor(max_workers=prefetch) as executor: + futures = [] + for shard_file in checkpoint_files: + future = executor.submit(load_optim_or_model_shard, shard_file, is_optim, use_safetensors) + futures.append(future) + for future in concurrent.futures.as_completed(futures): + yield future.result() diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 000000000..04e75ed0b --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,10 @@ +import gc + +from colossalai.accelerator import get_accelerator + + +def 
pytest_runtest_setup(item):
+    # Called before each test runs: release cached accelerator memory
+    # so tests start from a clean memory state.
+    accelerator = get_accelerator()
+    accelerator.empty_cache()
+    gc.collect()
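
Note on the pattern: the new load_state_dict_shards helper in colossalai/checkpoint_io/utils.py is what realizes the "load-pin overlap" in the patch title. With low_cpu_mem_mode=False, shard files are read by a small thread pool while the caller pins each already-loaded shard via create_pinned_state_dict, so disk I/O overlaps with host-memory pinning. The sketch below reproduces that pattern in isolation; load_shard and iter_shards_prefetched are hypothetical names for illustration, not identifiers from the patch.

    import concurrent.futures
    from typing import Generator, List

    import torch


    def load_shard(path: str) -> dict:
        # Hypothetical stand-in for the patch's real loaders
        # (load_flat / load_shard_state_dict): read one shard to CPU.
        return torch.load(path, map_location="cpu")


    def iter_shards_prefetched(paths: List[str], prefetch: int = 3) -> Generator[dict, None, None]:
        # Submit every shard load to a small thread pool. While the consumer
        # pins the tensors of one shard, the pool keeps reading the next
        # shards from disk, overlapping load and pin.
        with concurrent.futures.ThreadPoolExecutor(max_workers=prefetch) as executor:
            futures = [executor.submit(load_shard, p) for p in paths]
            # as_completed yields shards in completion order, not file order;
            # that is safe here because each shard carries its own keys.
            for future in concurrent.futures.as_completed(futures):
                yield future.result()

One design consequence worth knowing: max_workers=prefetch bounds how many shards are read concurrently, but all loads are submitted up front, so shards that finish before the consumer drains them stay resident in CPU memory until consumed. That is why the callers keep the sequential path when low_cpu_mem_mode=True.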