mirror of https://github.com/hpcaitech/ColossalAI
[checkpointio] support load-pin overlap (#6177)
* [checkpointio] support load-pin overlap
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* [test] add conftest

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 479067e9bc
commit ee81366cac
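What "load-pin overlap" means here: reading a checkpoint shard from disk (I/O-bound) and pinning it into page-locked host memory (memcpy-bound) used to run strictly one after the other. This commit routes all shard loading through a new generator, load_state_dict_shards, which, when low_cpu_mem_mode is off, reads upcoming shards on background threads while the caller pins and consumes the shard it already has. A minimal sketch of the pattern, using generic names rather than the ColossalAI API:

import concurrent.futures
from typing import Callable, Generator, List

def overlapped_shards(
    files: List[str], load: Callable[[str], dict], prefetch: int = 3
) -> Generator[dict, None, None]:
    # Worker threads read shards from disk; the consumer overlaps its
    # pinning/processing of one shard with the loading of the next ones.
    with concurrent.futures.ThreadPoolExecutor(max_workers=prefetch) as pool:
        futures = [pool.submit(load, f) for f in files]
        for fut in concurrent.futures.as_completed(futures):
            yield fut.result()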
@@ -20,7 +20,7 @@ from colossalai.checkpoint_io.utils import (
     create_pinned_state_dict,
     get_model_base_filenames,
     get_optimizer_base_filenames,
-    load_shard_state_dict,
+    load_state_dict_shards,
     save_config_file,
     save_state_dict,
     save_state_dict_shards,
@@ -29,7 +29,6 @@ from colossalai.cluster import DistCoordinator, ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.logging import get_dist_logger
 from colossalai.shardformer import ShardConfig, ShardFormer
-from colossalai.utils.safetensors import load_flat
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.memory_tracer import MemStats
 
@@ -350,11 +349,9 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
 
         # Load optimizer states from shard files under checkpoint path.
         # For each file, only load the states managed by current process.
-        for shard_file in checkpoint_files:
-            if shard_file.endswith(".safetensors"):
-                state_dict_shard = load_flat(shard_file)
-            else:
-                state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False)
+        for state_dict_shard in load_state_dict_shards(
+            checkpoint_files, True, False, low_cpu_mem_mode=low_cpu_mem_mode
+        ):
             if not low_cpu_mem_mode:
                 state_dict_shard = create_pinned_state_dict(state_dict_shard, empty=False, num_threads=num_threads)
             optimizer.load_param_states(state_dict_shard)
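The positional arguments above read as is_optim=True, use_safetensors=False: the branching on file extension that used to live at each call site (load_flat for flat safetensors, load_shard_state_dict otherwise) now happens inside load_state_dict_shards. Pinning stays at the call site and runs only when low_cpu_mem_mode is off, which is exactly the mode in which the generator prefetches, so the pin of one shard overlaps the background load of the next.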
@@ -24,8 +24,8 @@ from colossalai.checkpoint_io.utils import (
     get_optimizer_base_filenames,
     get_shard_filename,
     load_param_groups_into_optimizer,
-    load_shard_state_dict,
     load_state_dict,
+    load_state_dict_shards,
     load_states_into_optimizer,
     save_param_groups,
     save_state_dict,
@@ -276,13 +276,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
 
         checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
 
-        for shard_file in checkpoint_files:
-            if shard_file.endswith(".safetensors"):
-                from colossalai.utils.safetensors import load_flat
-
-                state_dict = load_flat(shard_file)
-            else:
-                state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
+        for state_dict in load_state_dict_shards(checkpoint_files, True, False, low_cpu_mem_mode):
             # shard state dict
             for param_idx, state in state_dict.items():
                 for k, v in state.items():
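Same replacement in the low-level ZeRO path; the function-local `from colossalai.utils.safetensors import load_flat` import disappears along with the extension check, since format dispatch now lives in colossalai.checkpoint_io.utils.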
@@ -255,8 +255,8 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
 
         fsdp_state_dict = {}
-        for shard_file in checkpoint_files:
-            fsdp_state_dict.update(utils.load_shard_state_dict(Path(shard_file), use_safetensors))
+        for state_dict in utils.load_state_dict_shards(checkpoint_files, False, use_safetensors):
+            fsdp_state_dict.update(state_dict)
 
         with FSDP.state_dict_type(model.unwrap(), StateDictType.FULL_STATE_DICT):
             model.unwrap().load_state_dict(fsdp_state_dict, strict=False)
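For model weights the generator is called with is_optim=False, so each yielded shard comes from load_shard_state_dict with the caller's use_safetensors flag; the shards cover disjoint keys, so merging them with dict.update is order-insensitive.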
@@ -388,11 +388,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         # Load param
         fsdp_optim_state = {}
         checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
-        for shard_file in checkpoint_files:
-            if shard_file.endswith(".safetensors"):
-                state_dict_shard = load_flat(shard_file, seperator=".")
-            else:
-                state_dict_shard = utils.load_shard_state_dict(Path(shard_file), use_safetensors=False)
+        for state_dict_shard in utils.load_state_dict_shards(checkpoint_files, True, False):
             fsdp_optim_state.update(state_dict_shard)
 
         fsdp_optim_dict = dict(state=fsdp_optim_state, param_groups=saved_param_groups)
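Note that both FSDP call sites omit low_cpu_mem_mode, so they take the default (True) and still load shards sequentially; only the Gemini, low-level ZeRO, and general checkpoint paths opt into the prefetching overlap in this commit.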
@@ -18,9 +18,9 @@ from .utils import (
     get_optimizer_base_filenames,
     is_safetensors_available,
     load_param_groups_into_optimizer,
-    load_shard_state_dict,
     load_state_dict,
     load_state_dict_into_model,
+    load_state_dict_shards,
     load_states_into_optimizer,
     save_config_file,
     save_param_groups,
@@ -94,11 +94,7 @@ class GeneralCheckpointIO(CheckpointIO):
 
         checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
 
-        for shard_file in checkpoint_files:
-            if shard_file.endswith(".safetensors"):
-                state_dict = load_flat(shard_file)
-            else:
-                state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
+        for state_dict in load_state_dict_shards(checkpoint_files, True, False, low_cpu_mem_mode):
             if not low_cpu_mem_mode:
                 state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
             load_states_into_optimizer(optimizer, state_dict, id_map)
@@ -295,8 +291,7 @@ class GeneralCheckpointIO(CheckpointIO):
         checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
         missing_keys = []
 
-        for shard_file in checkpoint_files:
-            state_dict = load_shard_state_dict(Path(shard_file), use_safetensors)
+        for state_dict in load_state_dict_shards(checkpoint_files, False, use_safetensors, low_cpu_mem_mode):
             if not low_cpu_mem_mode:
                 state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
             load_state_dict_into_model(model, state_dict, missing_keys, strict, load_sub_module)
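The model-shard loop in GeneralCheckpointIO passes is_optim=False and forwards both use_safetensors and low_cpu_mem_mode, so plain sharded model checkpoints get the same load-pin overlap when low_cpu_mem_mode=False.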
@@ -6,7 +6,7 @@ from collections import abc as container_abcs
 from collections import defaultdict
 from itertools import chain
 from pathlib import Path
-from typing import Dict, Iterator, List, Mapping, Optional, OrderedDict, Tuple, Union
+from typing import Dict, Generator, Iterator, List, Mapping, Optional, OrderedDict, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -21,7 +21,7 @@ from colossalai.tensor.d_tensor import (
     to_global,
     to_global_for_customized_distributed_tensor,
 )
-from colossalai.utils.safetensors import _flatten_optim_state_dict
+from colossalai.utils.safetensors import _flatten_optim_state_dict, load_flat
 
 SAFE_WEIGHTS_NAME = "model.safetensors"
 WEIGHTS_NAME = "pytorch_model.bin"
@@ -972,3 +972,35 @@ def create_pinned_state_dict(
             idx = future_to_idx[future]
             elems[idx] = future.result()
     return tree_unflatten(elems, spec)
+
+
+def load_optim_or_model_shard(path: str, is_optim: bool, use_safetensors: bool) -> dict:
+    if is_optim:
+        if path.endswith(".safetensors"):
+            state_dict = load_flat(path)
+        else:
+            state_dict = load_shard_state_dict(Path(path), use_safetensors=False)
+    else:
+        state_dict = load_shard_state_dict(Path(path), use_safetensors)
+    return state_dict
+
+
+def load_state_dict_shards(
+    checkpoint_files: List[str],
+    is_optim: bool,
+    use_safetensors: bool,
+    low_cpu_mem_mode: bool = True,
+    prefetch: int = 3,
+) -> Generator[dict, None, None]:
+    if low_cpu_mem_mode:
+        for shard_file in checkpoint_files:
+            state_dict = load_optim_or_model_shard(shard_file, is_optim, use_safetensors)
+            yield state_dict
+    else:
+        with concurrent.futures.ThreadPoolExecutor(max_workers=prefetch) as executor:
+            futures = []
+            for shard_file in checkpoint_files:
+                future = executor.submit(load_optim_or_model_shard, shard_file, is_optim, use_safetensors)
+                futures.append(future)
+            for future in concurrent.futures.as_completed(futures):
+                yield future.result()
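This generator is the heart of the change: in low_cpu_mem_mode it degrades to the old sequential loop, otherwise a thread pool of size prefetch loads shards ahead of the consumer and as_completed yields them as they arrive — in completion order, not file order, which is safe because every call site merges disjoint shards. A usage sketch; the shard file names here are made up for illustration:

from colossalai.checkpoint_io.utils import create_pinned_state_dict, load_state_dict_shards

shard_files = ["optim-00001.safetensors", "optim-00002.safetensors"]  # hypothetical paths
for shard in load_state_dict_shards(shard_files, is_optim=True, use_safetensors=False, low_cpu_mem_mode=False):
    # Pinning this shard overlaps with the background loading of the next ones.
    pinned = create_pinned_state_dict(shard, empty=False, num_threads=4)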
@@ -0,0 +1,10 @@
+import gc
+
+from colossalai.accelerator import get_accelerator
+
+
+def pytest_runtest_setup(item):
+    # Runs before each test in this directory: release cached accelerator memory.
+    accelerator = get_accelerator()
+    accelerator.empty_cache()
+    gc.collect()
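The new conftest hook runs before every test in its directory, emptying the accelerator cache and forcing a GC pass so device memory released by one checkpoint test is actually available to the next.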