mirror of https://github.com/hpcaitech/ColossalAI
[checkpointio]support asyncio for 3d (#6152)
* fix
* fix
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Update utils.py
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent aaafb38851
commit 130229fdcb
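For orientation, the async checkpointing added here is driven through the regular booster API. A minimal usage sketch, modeled on the updated tests at the bottom of this diff (distributed init, the plugin choice, and the toy model are assumptions for illustration, not part of the commit):

```python
import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

# Assumed setup: distributed init (e.g. colossalai.launch_from_torch) is already done.
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters())
booster = Booster(plugin=GeminiPlugin())
model, optimizer, *_ = booster.boost(model, optimizer)

# use_async=True stages tensors into pinned CPU buffers and writes .safetensors
# files through background writers instead of blocking the training rank.
booster.save_model(model, "ckpt/model", shard=True, size_per_shard=32, use_async=True)
booster.save_optimizer(optimizer, "ckpt/optimizer", shard=True, size_per_shard=32, use_async=True)

# Drain the async writers (device-to-host copies, then file I/O) before reading back.
booster.checkpoint_io._sync_d2h()
booster.checkpoint_io._sync_io()
```

For unsharded saves, the updated tests append a `.safetensors` suffix to the checkpoint path when `use_async` is enabled.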
@ -17,6 +17,8 @@ from torch.utils.data.distributed import DistributedSampler
|
|||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
|
||||
from colossalai.checkpoint_io.utils import (
|
||||
async_save_state_dict_shards,
|
||||
create_pinned_state_dict,
|
||||
get_model_base_filenames,
|
||||
get_optimizer_base_filenames,
|
||||
load_shard_state_dict,
|
||||
|
@ -28,6 +30,7 @@ from colossalai.cluster import DistCoordinator, ProcessGroupMesh
|
|||
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.shardformer import ShardConfig, ShardFormer
|
||||
from colossalai.utils.safetensors import load_flat
|
||||
from colossalai.zero import GeminiDDP, GeminiOptimizer
|
||||
from colossalai.zero.gemini.memory_tracer import MemStats
|
||||
|
||||
|
@ -82,7 +85,15 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
|
|||
state_dict = model.state_dict(only_rank_0=True)
|
||||
if self.coordinator.is_master():
|
||||
if use_async:
|
||||
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async)
|
||||
from colossalai.utils.safetensors import save
|
||||
|
||||
if id(model) not in self.pinned_state_dicts:
|
||||
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
|
||||
for k, v in state_dict.items():
|
||||
self.pinned_state_dicts[id(model)][k].copy_(v)
|
||||
state_dict[k] = self.pinned_state_dicts[id(model)][k]
|
||||
writer = save(checkpoint, state_dict)
|
||||
self.async_writers.append(writer)
|
||||
else:
|
||||
save_state_dict(state_dict, checkpoint, use_safetensors)
|
||||
|
||||
|
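The `use_async` branch above shows the pinned-buffer staging used throughout this commit: the gathered state dict is copied into cached page-locked CPU tensors, and only those pinned tensors are handed to the background safetensors writer. A stripped-down sketch of that staging step in plain PyTorch (the helper name is mine, not a ColossalAI API):

```python
import torch

def stage_into_pinned(state_dict: dict, pinned_cache: dict) -> dict:
    """Copy each tensor into a reusable page-locked CPU buffer of the same shape/dtype."""
    staged = {}
    for name, tensor in state_dict.items():
        if name not in pinned_cache:
            # Same allocation pattern as the diff: torch.empty_like(..., pin_memory=True)
            pinned_cache[name] = torch.empty_like(tensor, pin_memory=True, device="cpu")
        pinned_cache[name].copy_(tensor)   # device-to-host copy into pinned memory
        staged[name] = pinned_cache[name]  # the async writer reads from the pinned buffer
    return staged
```

Keying the caches by `id(model)` / `id(optimizer)`, as the checkpoint IO classes do, lets repeated saves reuse the same pinned allocations.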
@ -106,6 +117,18 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
|
|||
assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!"
|
||||
state_dict = optimizer.state_dict()
|
||||
if self.coordinator.is_master():
|
||||
if use_async:
|
||||
from colossalai.utils.safetensors import _flatten_optim_state_dict, save
|
||||
|
||||
flatten_state_dict, metadata = _flatten_optim_state_dict(state_dict)
|
||||
if id(optimizer) not in self.pinned_state_dicts:
|
||||
self.pinned_state_dicts[id(optimizer)] = create_pinned_state_dict(flatten_state_dict)
|
||||
for k, v in flatten_state_dict.items():
|
||||
self.pinned_state_dicts[id(optimizer)][k].copy_(v)
|
||||
flatten_state_dict[k] = self.pinned_state_dicts[id(optimizer)][k]
|
||||
writer = save(checkpoint, flatten_state_dict, metadata)
|
||||
self.async_writers.append(writer)
|
||||
else:
|
||||
save_state_dict(state_dict, checkpoint, use_safetensors=False)
|
||||
|
||||
def load_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str):
|
||||
|
@ -137,17 +160,29 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
|
|||
|
||||
Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True)
|
||||
if use_async and self.coordinator.is_master():
|
||||
if id(model) not in self.pinned_state_dicts:
|
||||
self.pinned_state_dicts[id(model)] = {}
|
||||
pinned_state_dicts = self.pinned_state_dicts[id(model)]
|
||||
else:
|
||||
pinned_state_dicts = None
|
||||
state_dict_shard = model.state_dict_shard(
|
||||
max_shard_size=max_shard_size, only_rank_0=True, pinned_state_dicts=pinned_state_dicts
|
||||
)
|
||||
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
|
||||
index_file = CheckpointIndexFile(checkpoint_path)
|
||||
|
||||
# Save shards of model weights.
|
||||
is_master = self.coordinator.is_master()
|
||||
if use_async:
|
||||
super().save_sharded_model(
|
||||
model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors, use_async
|
||||
total_size, writers = async_save_state_dict_shards(
|
||||
sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint_path,
|
||||
index_file=index_file,
|
||||
base_filename=weights_name,
|
||||
is_master=is_master,
|
||||
)
|
||||
|
||||
self.async_writers.extend(writers)
|
||||
else:
|
||||
total_size = save_state_dict_shards(
|
||||
sharded_state_dict=state_dict_shard,
|
||||
|
@ -201,7 +236,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
|
|||
Path(checkpoint).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Preparing file paths and index file.
|
||||
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
|
||||
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix, use_safetensors=use_async)
|
||||
index_file = CheckpointIndexFile(checkpoint)
|
||||
index_file.append_meta_data("param_groups", param_group_file)
|
||||
|
||||
|
@ -212,9 +247,28 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
|
|||
torch.save(param_groups, group_file_path)
|
||||
|
||||
# States are broken into shards within max_shard_size.
|
||||
state_dict_shard = optimizer.state_shard(prefix=prefix, max_shard_size=size_per_shard, only_rank_0=True)
|
||||
if use_async and self.coordinator.is_master():
|
||||
if id(optimizer) not in self.pinned_state_dicts:
|
||||
self.pinned_state_dicts[id(optimizer)] = {}
|
||||
pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
|
||||
else:
|
||||
pinned_state_dicts = None
|
||||
state_dict_shard = optimizer.state_shard(
|
||||
prefix=prefix, max_shard_size=size_per_shard, only_rank_0=True, pinned_state_dicts=pinned_state_dicts
|
||||
)
|
||||
|
||||
# Save shards of optimizer states.
|
||||
if use_async:
|
||||
total_size, writers = async_save_state_dict_shards(
|
||||
sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=states_name,
|
||||
is_master=self.coordinator.is_master(),
|
||||
state_preprocess=True,
|
||||
)
|
||||
self.async_writers.extend(writers)
|
||||
else:
|
||||
total_size = save_state_dict_shards(
|
||||
sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint,
|
||||
|
@ -264,6 +318,9 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
|
|||
# Load optimizer states from shard files under checkpoint path.
|
||||
# For each file, only load the states managed by current process.
|
||||
for shard_file in checkpoint_files:
|
||||
if shard_file.endswith(".safetensors"):
|
||||
state_dict_shard = load_flat(shard_file)
|
||||
else:
|
||||
state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False)
|
||||
optimizer.load_param_states(state_dict_shard)
|
||||
del state_dict_shard
|
||||
|
|
|
@ -1488,7 +1488,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
)
|
||||
|
||||
def get_checkpoint_io(self) -> CheckpointIO:
|
||||
return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
|
||||
return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.sp_group, self.zero_stage)
|
||||
|
||||
def no_sync(self, model: Module, optimizer: OptimizerWrapper) -> Iterator[None]:
|
||||
assert (
|
||||
|
|
|
@ -404,7 +404,13 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|||
|
||||
def get_checkpoint_io(self) -> MoECheckpointIO:
|
||||
return MoECheckpointIO(
|
||||
self.dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage
|
||||
self.dp_group,
|
||||
self.pp_group,
|
||||
self.tp_group,
|
||||
self.sp_group,
|
||||
self.ep_group,
|
||||
self.moe_dp_group,
|
||||
self.zero_stage,
|
||||
)
|
||||
|
||||
def configure(
|
||||
|
|
|
@ -60,7 +60,7 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
|
|||
"""
|
||||
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
|
||||
if self.coordinator.is_master():
|
||||
super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
|
||||
super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor, use_async=use_async)
|
||||
|
||||
def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
|
||||
"""
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import os
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple
|
||||
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -26,9 +26,11 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
|||
from torch.utils.data import DataLoader
|
||||
|
||||
from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO, utils
|
||||
from colossalai.checkpoint_io.utils import async_save_state_dict_shards, create_pinned_state_dict
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils.safetensors import load_flat
|
||||
|
||||
from .dp_plugin_base import DPPluginBase
|
||||
|
||||
|
@ -49,8 +51,36 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
|
|||
|
||||
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: Path):
|
||||
assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before loading!"
|
||||
if checkpoint.endswith(".safetensors"):
|
||||
checkpoint = load_flat(checkpoint, seperator=".")
|
||||
else:
|
||||
checkpoint = utils.load_state_dict(checkpoint)
|
||||
|
||||
fsdp_model = optimizer.unwrap_model()
|
||||
full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=False)
|
||||
start_index = 0
|
||||
id2name = {}
|
||||
|
||||
def get_index_mapping(group: Dict[str, Any]) -> Dict[str, Any]:
|
||||
nonlocal start_index
|
||||
start_num = len(id2name)
|
||||
id2name.update({i: p for i, p in enumerate(group["params"], start_index) if i not in id2name})
|
||||
end_num = len(id2name)
|
||||
start_index += end_num - start_num
|
||||
|
||||
for g in full_optimizer_state["param_groups"]:
|
||||
get_index_mapping(g)
|
||||
|
||||
new_state = {}
|
||||
for key, value in checkpoint["state"].items():
|
||||
new_state[id2name[int(key)]] = value
|
||||
checkpoint["state"] = new_state
|
||||
for g in checkpoint["param_groups"]:
|
||||
new_group = []
|
||||
for param_id in g["params"]:
|
||||
new_group.append(id2name[param_id])
|
||||
g["params"] = new_group
|
||||
|
||||
sharded_osd = FSDP.scatter_full_optim_state_dict(checkpoint, fsdp_model)
|
||||
optimizer.load_state_dict(sharded_osd)
|
||||
|
||||
|
@ -65,7 +95,21 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
|
|||
cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
|
||||
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
|
||||
full_model_state = model.state_dict()
|
||||
utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors)
|
||||
if self.coordinator.is_master():
|
||||
if use_async:
|
||||
from colossalai.utils.safetensors import save
|
||||
|
||||
if id(model) not in self.pinned_state_dicts:
|
||||
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(full_model_state)
|
||||
for k, v in full_model_state.items():
|
||||
self.pinned_state_dicts[id(model)][k].copy_(v)
|
||||
full_model_state[k] = self.pinned_state_dicts[id(model)][k]
|
||||
writer = save(checkpoint, full_model_state)
|
||||
self.async_writers.append(writer)
|
||||
else:
|
||||
utils.save_state_dict(
|
||||
full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors
|
||||
)
|
||||
|
||||
def save_unsharded_optimizer(
|
||||
self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False
|
||||
|
@ -75,7 +119,42 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
|
|||
"""
|
||||
assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before saving!"
|
||||
fsdp_model = optimizer.unwrap_model()
|
||||
|
||||
full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True)
|
||||
|
||||
if self.coordinator.is_master():
|
||||
|
||||
# Save order indices instead of Tensors
|
||||
name2id: Dict[str, int] = {}
|
||||
start_index = 0
|
||||
|
||||
def pack_group(group: Dict[str, Any]) -> Dict[str, Any]:
|
||||
nonlocal start_index
|
||||
packed = {k: v for k, v in group.items() if k != "params"}
|
||||
name2id.update({p: i for i, p in enumerate(group["params"], start_index) if p not in name2id})
|
||||
packed["params"] = [name2id[p] for p in group["params"]]
|
||||
start_index += len(packed["params"])
|
||||
return packed
|
||||
|
||||
param_groups = [pack_group(g) for g in full_optimizer_state["param_groups"]]
|
||||
full_optimizer_state["param_groups"] = param_groups
|
||||
new_state = {}
|
||||
for key, value in full_optimizer_state["state"].items():
|
||||
new_state[name2id[key]] = value
|
||||
full_optimizer_state["state"] = new_state
|
||||
|
||||
if use_async:
|
||||
from colossalai.utils.safetensors import _flatten_optim_state_dict, save
|
||||
|
||||
flatten_state_dict, metadata = _flatten_optim_state_dict(full_optimizer_state, seperator=".")
|
||||
if id(optimizer) not in self.pinned_state_dicts:
|
||||
self.pinned_state_dicts[id(optimizer)] = create_pinned_state_dict(flatten_state_dict)
|
||||
for k, v in flatten_state_dict.items():
|
||||
self.pinned_state_dicts[id(optimizer)][k].copy_(v)
|
||||
flatten_state_dict[k] = self.pinned_state_dicts[id(optimizer)][k]
|
||||
writer = save(checkpoint, state_dict=flatten_state_dict, metadata=metadata)
|
||||
self.async_writers.append(writer)
|
||||
else:
|
||||
utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False)
|
||||
|
||||
def save_sharded_model(
|
||||
|
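`FSDP.full_optim_state_dict` keys optimizer states by parameter name (FQN), while the checkpoint format written above stores integer indices; `pack_group` builds that name-to-index mapping, and `load_unsharded_optimizer` rebuilds the inverse (`id2name`) before calling `FSDP.scatter_full_optim_state_dict`. A toy before/after of the packing (parameter names are hypothetical):

```python
# Hypothetical illustration of the transform done by pack_group() above.
fsdp_state = {
    "state": {"encoder.weight": {"exp_avg": "..."}, "encoder.bias": {"exp_avg": "..."}},
    "param_groups": [{"lr": 1e-3, "params": ["encoder.weight", "encoder.bias"]}],
}
packed = {
    "state": {0: {"exp_avg": "..."}, 1: {"exp_avg": "..."}},
    "param_groups": [{"lr": 1e-3, "params": [0, 1]}],
}
```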
@ -102,12 +181,30 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
|
|||
):
|
||||
state_dict = model.unwrap().state_dict()
|
||||
|
||||
state_dict_shard = utils.shard_model_checkpoint(state_dict, max_shard_size=size_per_shard)
|
||||
if use_async and self.coordinator.is_master():
|
||||
if id(model) not in self.pinned_state_dicts:
|
||||
self.pinned_state_dicts[id(model)] = {}
|
||||
pinned_state_dicts = self.pinned_state_dicts[id(model)]
|
||||
else:
|
||||
pinned_state_dicts = None
|
||||
state_dict_shard = utils.shard_model_checkpoint(
|
||||
state_dict, max_shard_size=size_per_shard, pinned_state_dicts=pinned_state_dicts
|
||||
)
|
||||
|
||||
weights_name, save_index_file = utils.get_model_base_filenames(prefix, use_safetensors)
|
||||
index_file = CheckpointIndexFile(checkpoint_path)
|
||||
|
||||
# In general cases, is_master is set to True to get the right behavior.
|
||||
if use_async:
|
||||
total_size, writers = async_save_state_dict_shards(
|
||||
sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint_path,
|
||||
index_file=index_file,
|
||||
base_filename=weights_name,
|
||||
is_master=self.coordinator.is_master(),
|
||||
)
|
||||
self.async_writers.extend(writers)
|
||||
else:
|
||||
total_size = utils.save_state_dict_shards(
|
||||
sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint_path,
|
||||
|
@ -188,18 +285,58 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
|
|||
)
|
||||
|
||||
if self.coordinator.is_master():
|
||||
|
||||
# Save order indices instead of Tensors
|
||||
name2id: Dict[str, int] = {}
|
||||
start_index = 0
|
||||
|
||||
def pack_group(group: Dict[str, Any]) -> Dict[str, Any]:
|
||||
nonlocal start_index
|
||||
packed = {k: v for k, v in group.items() if k != "params"}
|
||||
name2id.update({p: i for i, p in enumerate(group["params"], start_index) if p not in name2id})
|
||||
packed["params"] = [name2id[p] for p in group["params"]]
|
||||
start_index += len(packed["params"])
|
||||
return packed
|
||||
|
||||
param_groups = [pack_group(g) for g in fsdp_optim_state["param_groups"]]
|
||||
fsdp_optim_state["param_groups"] = param_groups
|
||||
new_state = {}
|
||||
for key, value in fsdp_optim_state["state"].items():
|
||||
new_state[name2id[key]] = value
|
||||
fsdp_optim_state["state"] = new_state
|
||||
|
||||
# Preparing file paths and index file.
|
||||
states_name, save_index_file, param_group_file = utils.get_optimizer_base_filenames(prefix)
|
||||
states_name, save_index_file, param_group_file = utils.get_optimizer_base_filenames(
|
||||
prefix, use_safetensors=use_async
|
||||
)
|
||||
index_file = CheckpointIndexFile(checkpoint)
|
||||
|
||||
index_file.append_meta_data("param_groups", param_group_file)
|
||||
group_file_path = os.path.join(checkpoint, param_group_file)
|
||||
utils.save_param_groups(fsdp_optim_state, group_file_path)
|
||||
|
||||
sharded_state = utils.shard_optimizer_checkpoint(fsdp_optim_state, max_shard_size=size_per_shard)
|
||||
|
||||
if use_async:
|
||||
if id(optimizer) not in self.pinned_state_dicts:
|
||||
self.pinned_state_dicts[id(optimizer)] = {}
|
||||
pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
|
||||
else:
|
||||
pinned_state_dicts = None
|
||||
sharded_state = utils.shard_optimizer_checkpoint(
|
||||
fsdp_optim_state, max_shard_size=size_per_shard, pinned_state_dicts=pinned_state_dicts
|
||||
)
|
||||
# Save shards of optimizer states.
|
||||
# In general cases, is_master is set to True to get the right behavior.
|
||||
if use_async:
|
||||
total_size, writers = async_save_state_dict_shards(
|
||||
sharded_state_dict=sharded_state,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=states_name,
|
||||
is_master=self.coordinator.is_master(),
|
||||
state_preprocess=True,
|
||||
)
|
||||
self.async_writers.extend(writers)
|
||||
else:
|
||||
total_size = utils.save_state_dict_shards(
|
||||
sharded_state_dict=sharded_state,
|
||||
checkpoint=checkpoint,
|
||||
|
@ -239,11 +376,39 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
|
|||
fsdp_optim_state = {}
|
||||
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
|
||||
for shard_file in checkpoint_files:
|
||||
if shard_file.endswith(".safetensors"):
|
||||
state_dict_shard = load_flat(shard_file, seperator=".")
|
||||
else:
|
||||
state_dict_shard = utils.load_shard_state_dict(Path(shard_file), use_safetensors=False)
|
||||
fsdp_optim_state.update(state_dict_shard)
|
||||
|
||||
fsdp_optim_dict = dict(state=fsdp_optim_state, param_groups=saved_param_groups)
|
||||
|
||||
fsdp_model = optimizer.unwrap_model()
|
||||
full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model.unwrap(), optim=optimizer, rank0_only=False)
|
||||
start_index = 0
|
||||
id2name = {}
|
||||
|
||||
def get_index_mapping(group: Dict[str, Any]) -> Dict[str, Any]:
|
||||
nonlocal start_index
|
||||
start_num = len(id2name)
|
||||
id2name.update({i: p for i, p in enumerate(group["params"], start_index) if i not in id2name})
|
||||
end_num = len(id2name)
|
||||
start_index += end_num - start_num
|
||||
|
||||
for g in full_optimizer_state["param_groups"]:
|
||||
get_index_mapping(g)
|
||||
|
||||
new_state = {}
|
||||
for key, value in fsdp_optim_dict["state"].items():
|
||||
new_state[id2name[int(key)]] = value
|
||||
fsdp_optim_dict["state"] = new_state
|
||||
for g in fsdp_optim_dict["param_groups"]:
|
||||
new_group = []
|
||||
for param_id in g["params"]:
|
||||
new_group.append(id2name[param_id])
|
||||
g["params"] = new_group
|
||||
|
||||
with FSDP.state_dict_type(optimizer.unwrap_model().unwrap(), StateDictType.FULL_STATE_DICT):
|
||||
fsdp_state = FSDP.optim_state_dict_to_load(
|
||||
model=optimizer.unwrap_model().unwrap(), optim=optimizer, optim_state_dict=fsdp_optim_dict
|
||||
|
|
|
@ -8,10 +8,12 @@ from typing import Optional
|
|||
import torch.nn as nn
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from colossalai.utils.safetensors import load_flat
|
||||
|
||||
from .checkpoint_io_base import CheckpointIO
|
||||
from .index_file import CheckpointIndexFile
|
||||
from .utils import (
|
||||
async_save_state_dict_shards,
|
||||
async_move_save_state_dict_shards,
|
||||
create_pinned_state_dict,
|
||||
get_model_base_filenames,
|
||||
get_optimizer_base_filenames,
|
||||
|
@ -47,10 +49,6 @@ class GeneralCheckpointIO(CheckpointIO):
|
|||
):
|
||||
state_dict = model.state_dict()
|
||||
|
||||
# TODO(FrankLeeeee): add support for gather_dtensor
|
||||
if gather_dtensor:
|
||||
pass
|
||||
|
||||
if use_async:
|
||||
from colossalai.utils.safetensors import move_and_save
|
||||
|
||||
|
@ -58,7 +56,6 @@ class GeneralCheckpointIO(CheckpointIO):
|
|||
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
|
||||
writer = move_and_save(checkpoint, state_dict, self.pinned_state_dicts[id(model)])
|
||||
self.async_writers.append(writer)
|
||||
|
||||
else:
|
||||
# save the checkpoint
|
||||
save_state_dict(state_dict, checkpoint, use_safetensors)
|
||||
|
@ -83,6 +80,9 @@ class GeneralCheckpointIO(CheckpointIO):
|
|||
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
|
||||
|
||||
for shard_file in checkpoint_files:
|
||||
if shard_file.endswith(".safetensors"):
|
||||
state_dict = load_flat(shard_file)
|
||||
else:
|
||||
state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
|
||||
load_states_into_optimizer(optimizer, state_dict, id_map)
|
||||
|
||||
|
@ -116,7 +116,7 @@ class GeneralCheckpointIO(CheckpointIO):
|
|||
sharded_state = shard_optimizer_checkpoint(state_dict, max_shard_size=size_per_shard)
|
||||
|
||||
# Preparing file paths and index file.
|
||||
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
|
||||
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix, use_safetensors=use_async)
|
||||
index_file = CheckpointIndexFile(checkpoint)
|
||||
|
||||
# Store the information of param groups to param_group_file.
|
||||
|
@ -126,6 +126,20 @@ class GeneralCheckpointIO(CheckpointIO):
|
|||
|
||||
# Save shards of optimizer states.
|
||||
# In general cases, is_master is set to True to get the right behavior.
|
||||
if use_async:
|
||||
pinned_state_dict = self.pinned_state_dicts.get(id(optimizer), None)
|
||||
total_size, new_pinned_state_dict, writers = async_move_save_state_dict_shards(
|
||||
sharded_state_dict=sharded_state,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=states_name,
|
||||
is_master=True,
|
||||
pinned_state_dict=pinned_state_dict,
|
||||
state_preprocess=True,
|
||||
)
|
||||
self.pinned_state_dicts[id(optimizer)] = new_pinned_state_dict
|
||||
self.async_writers.extend(writers)
|
||||
else:
|
||||
total_size = save_state_dict_shards(
|
||||
sharded_state_dict=sharded_state,
|
||||
checkpoint=checkpoint,
|
||||
|
@ -145,6 +159,9 @@ class GeneralCheckpointIO(CheckpointIO):
|
|||
)
|
||||
|
||||
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
|
||||
if checkpoint.endswith(".safetensors"):
|
||||
checkpoint = load_flat(checkpoint)
|
||||
else:
|
||||
checkpoint = load_state_dict(checkpoint)
|
||||
optimizer.load_state_dict(checkpoint)
|
||||
|
||||
|
@ -156,7 +173,22 @@ class GeneralCheckpointIO(CheckpointIO):
|
|||
use_async: bool = False,
|
||||
):
|
||||
# TODO(FrankLeeeee): handle distributed tensors
|
||||
save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False)
|
||||
state_dict = optimizer.state_dict()
|
||||
if use_async:
|
||||
from colossalai.utils.safetensors import _flatten_optim_state_dict, move_and_save
|
||||
|
||||
flatten_state_dict, metadata = _flatten_optim_state_dict(state_dict)
|
||||
if id(optimizer) not in self.pinned_state_dicts:
|
||||
self.pinned_state_dicts[id(optimizer)] = create_pinned_state_dict(flatten_state_dict)
|
||||
writer = move_and_save(
|
||||
path=checkpoint,
|
||||
state_dict=flatten_state_dict,
|
||||
state_dict_pinned=self.pinned_state_dicts[id(optimizer)],
|
||||
metadata=metadata,
|
||||
)
|
||||
self.async_writers.append(writer)
|
||||
else:
|
||||
save_state_dict(state_dict, checkpoint, use_safetensors=False)
|
||||
|
||||
def save_sharded_model(
|
||||
self,
|
||||
|
@ -186,7 +218,7 @@ class GeneralCheckpointIO(CheckpointIO):
|
|||
|
||||
if use_async:
|
||||
pinned_state_dict = self.pinned_state_dicts.get(id(model), None)
|
||||
total_size, new_pinned_state_dict, writers = async_save_state_dict_shards(
|
||||
total_size, new_pinned_state_dict, writers = async_move_save_state_dict_shards(
|
||||
sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint_path,
|
||||
index_file=index_file,
|
||||
|
|
|
@ -22,6 +22,7 @@ from colossalai.tensor.padded_tensor import (
|
|||
to_unpadded_tensor,
|
||||
)
|
||||
from colossalai.utils import get_current_device, get_non_persistent_buffers_set
|
||||
from colossalai.utils.safetensors import _flatten_optim_state_dict, load_flat
|
||||
|
||||
from .general_checkpoint_io import GeneralCheckpointIO
|
||||
from .index_file import CheckpointIndexFile
|
||||
|
@ -69,6 +70,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
dp_group: ProcessGroup,
|
||||
pp_group: ProcessGroup,
|
||||
tp_group: ProcessGroup,
|
||||
sp_group: ProcessGroup,
|
||||
zero_stage: int,
|
||||
verbose: bool = True,
|
||||
) -> None:
|
||||
|
@ -76,9 +78,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
self.global_dp_group = dp_group
|
||||
self.pp_group = pp_group
|
||||
self.tp_group = tp_group
|
||||
self.sp_group = sp_group
|
||||
self.dp_rank = dist.get_rank(self.global_dp_group)
|
||||
self.tp_rank = dist.get_rank(self.tp_group)
|
||||
self.pp_rank = dist.get_rank(self.pp_group)
|
||||
self.sp_rank = dist.get_rank(self.sp_group)
|
||||
self.global_dp_size = dist.get_world_size(dp_group)
|
||||
self.pp_size = dist.get_world_size(pp_group)
|
||||
self.tp_size = dist.get_world_size(tp_group)
|
||||
|
@ -88,7 +92,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
|
||||
@staticmethod
|
||||
def _model_sharder(
|
||||
model: nn.Module, prefix: str = "", keep_vars: bool = False, size_per_shard: int = 1024
|
||||
model: nn.Module,
|
||||
prefix: str = "",
|
||||
keep_vars: bool = False,
|
||||
size_per_shard: int = 1024,
|
||||
pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> Iterator[Tuple[OrderedDict, int]]:
|
||||
# An internal method that breaks state_dict of model into shards within limited size.
|
||||
|
||||
|
@ -102,6 +110,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
if is_padded_tensor(param):
|
||||
param = to_unpadded_tensor(param)
|
||||
param_ = gather_distributed_param(param, keep_vars=False)
|
||||
if pinned_state_dicts is not None:
|
||||
if (prefix + name) not in pinned_state_dicts:
|
||||
pinned_state_dicts[prefix + name] = torch.empty_like(param_, pin_memory=True, device="cpu")
|
||||
pinned_state_dicts[prefix + name].copy_(param_)
|
||||
param_ = pinned_state_dicts[prefix + name]
|
||||
block, block_size = state_dict_sharder.append_param(prefix + name, param_)
|
||||
if block is not None:
|
||||
yield block, block_size
|
||||
|
@ -111,6 +124,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
for name, buf in model.named_buffers():
|
||||
if buf is not None and name not in non_persist_buffers_set:
|
||||
buffer = buf if keep_vars else buf.detach()
|
||||
if pinned_state_dicts is not None:
|
||||
if (prefix + name) not in pinned_state_dicts:
|
||||
pinned_state_dicts[prefix + name] = torch.empty_like(buffer, pin_memory=True, device="cpu")
|
||||
pinned_state_dicts[prefix + name].copy_(buffer)
|
||||
buffer = pinned_state_dicts[prefix + name]
|
||||
block, block_size = state_dict_sharder.append_param(prefix + name, buffer)
|
||||
if block is not None:
|
||||
yield block, block_size
|
||||
|
@ -122,6 +140,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
is not torch.nn.Module.get_extra_state
|
||||
):
|
||||
extra_state = model.get_extra_state()
|
||||
if pinned_state_dicts is not None:
|
||||
if extra_state_key not in pinned_state_dicts:
|
||||
pinned_state_dicts[extra_state_key] = torch.empty_like(extra_state, pin_memory=True, device="cpu")
|
||||
pinned_state_dicts[extra_state_key].copy_(extra_state)
|
||||
extra_state = pinned_state_dicts[extra_state_key]
|
||||
block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state)
|
||||
if block is not None:
|
||||
yield block, block_size
|
||||
|
@ -136,6 +159,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
dp_group: ProcessGroup,
|
||||
tp_group: ProcessGroup,
|
||||
size_per_shard: int = 1024,
|
||||
pinned_state_dicts: Optional[Dict[int, Dict[str, torch.Tensor]]] = None,
|
||||
):
|
||||
# An internal method that breaks state_dict of optimizer into shards within limited size.
|
||||
|
||||
|
@ -153,6 +177,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
working_param = param
|
||||
|
||||
param_id = param_info["param2id"][id(working_param)]
|
||||
if pinned_state_dicts is not None:
|
||||
if param_id not in pinned_state_dicts:
|
||||
pinned_state_dicts[param_id] = {}
|
||||
original_shape = param_info["param2shape"][id(working_param)]
|
||||
state_ = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state(
|
||||
state,
|
||||
|
@ -162,6 +189,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
tp_group=tp_group,
|
||||
use_zero=use_zero,
|
||||
inplace=False,
|
||||
pinned_state_dicts=pinned_state_dicts[param_id] if pinned_state_dicts is not None else None,
|
||||
)
|
||||
|
||||
block, block_size = state_dict_sharder.append_optim_state(param_id, state_)
|
||||
|
@ -216,15 +244,31 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
|
||||
# Then collect the sharded parameters & buffers along tp_group.
|
||||
# Only devices with tp_rank == 0 are responsible for model saving.
|
||||
state_dict_shard = HybridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard)
|
||||
control_saving = self.tp_rank == 0 and self.sp_rank == 0
|
||||
if control_saving and use_async:
|
||||
if id(model) not in self.pinned_state_dicts:
|
||||
self.pinned_state_dicts[id(model)] = {}
|
||||
pinned_state_dicts = self.pinned_state_dicts[id(model)]
|
||||
else:
|
||||
pinned_state_dicts = None
|
||||
state_dict_shard = HybridParallelCheckpointIO._model_sharder(
|
||||
model, size_per_shard=size_per_shard, pinned_state_dicts=pinned_state_dicts
|
||||
)
|
||||
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
|
||||
index_file = CheckpointIndexFile(checkpoint)
|
||||
control_saving = self.tp_rank == 0
|
||||
|
||||
if self.pp_size == 1:
|
||||
# When pipeline is not used, save the model shards as in general checkpointIO
|
||||
if use_async:
|
||||
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async)
|
||||
total_size, writers = async_save_state_dict_shards(
|
||||
sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=weights_name,
|
||||
is_master=control_saving,
|
||||
state_preprocess=False,
|
||||
)
|
||||
self.async_writers.extend(writers)
|
||||
else:
|
||||
total_size = save_state_dict_shards(
|
||||
sharded_state_dict=state_dict_shard,
|
||||
|
@ -259,15 +303,16 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
|
||||
save_index_file = os.path.join("tmp_index_files", save_index_file)
|
||||
if use_async:
|
||||
total_size, returned_state_dict, writers = async_save_state_dict_shards(
|
||||
total_size, writers = async_save_state_dict_shards(
|
||||
sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=weights_name,
|
||||
is_master=control_saving,
|
||||
use_pp_format=True,
|
||||
n_write_entries=191,
|
||||
state_preprocess=False,
|
||||
)
|
||||
self.async_writers.extend(writers)
|
||||
else:
|
||||
total_size = save_state_dict_shards(
|
||||
sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint,
|
||||
|
@ -448,19 +493,39 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
|
||||
# Then collect the sharded states along dp_group(if using zero)/tp_group.
|
||||
# Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving.
|
||||
control_saving = self.dp_rank == 0 and self.tp_rank == 0 and self.sp_rank == 0
|
||||
|
||||
if use_async and control_saving:
|
||||
if id(optimizer) not in self.pinned_state_dicts:
|
||||
self.pinned_state_dicts[id(optimizer)] = {}
|
||||
pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
|
||||
else:
|
||||
pinned_state_dicts = None
|
||||
state_dict_shard = HybridParallelCheckpointIO._optimizer_sharder(
|
||||
optimizer,
|
||||
use_zero=self.use_zero,
|
||||
dp_group=self.global_dp_group,
|
||||
tp_group=self.tp_group,
|
||||
size_per_shard=size_per_shard,
|
||||
pinned_state_dicts=pinned_state_dicts,
|
||||
)
|
||||
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
|
||||
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix, use_safetensors=use_async)
|
||||
index_file = CheckpointIndexFile(checkpoint)
|
||||
control_saving = self.dp_rank == 0 and self.tp_rank == 0
|
||||
|
||||
if self.pp_size == 1:
|
||||
# When pipeline is not used, save the optimizer shards as in general checkpointIO
|
||||
if use_async:
|
||||
total_size, writers = async_save_state_dict_shards(
|
||||
sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=states_name,
|
||||
is_master=control_saving,
|
||||
use_pp_format=True,
|
||||
state_preprocess=True,
|
||||
)
|
||||
self.async_writers.extend(writers)
|
||||
else:
|
||||
total_size = save_state_dict_shards(
|
||||
sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint,
|
||||
|
@ -498,10 +563,25 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Manage filenames of sharded weights and index file for each pipeline stage.
|
||||
if not use_async:
|
||||
states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
|
||||
else:
|
||||
states_name = states_name.replace(".safetensors", f"-stage-{self.pp_rank+1:05d}-shard.safetensors")
|
||||
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
|
||||
save_index_file = os.path.join("tmp_index_files", save_index_file)
|
||||
|
||||
if use_async:
|
||||
total_size, writers = async_save_state_dict_shards(
|
||||
sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=states_name,
|
||||
is_master=control_saving,
|
||||
use_pp_format=True,
|
||||
state_preprocess=True,
|
||||
)
|
||||
self.async_writers.extend(writers)
|
||||
else:
|
||||
total_size = save_state_dict_shards(
|
||||
sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint,
|
||||
|
@ -622,6 +702,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
continue
|
||||
|
||||
file_path = os.path.join(ckpt_root_path, filename)
|
||||
if file_path.endswith(".safetensors"):
|
||||
state_dict = load_flat(file_path)
|
||||
else:
|
||||
state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
|
||||
load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
|
||||
loaded_file.add(filename)
|
||||
|
@ -672,7 +755,15 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
# When pipeline is not used, let master rank directly save the collected state_dict.
|
||||
if self.tp_rank == 0:
|
||||
if use_async:
|
||||
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async)
|
||||
from colossalai.utils.safetensors import save
|
||||
|
||||
if id(model) not in self.pinned_state_dicts:
|
||||
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
|
||||
for name, param in state_dict.items():
|
||||
self.pinned_state_dicts[id(model)][name].copy_(param)
|
||||
state_dict[name] = self.pinned_state_dicts[id(model)][name]
|
||||
writer = save(path=checkpoint, state_dict=state_dict)
|
||||
self.async_writers.append(writer)
|
||||
else:
|
||||
save_state_dict(state_dict, checkpoint, use_safetensors)
|
||||
else:
|
||||
|
@ -686,12 +777,14 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
for _state_dict in state_dict_list:
|
||||
complete_state_dict.update(_state_dict)
|
||||
if use_async:
|
||||
|
||||
from colossalai.utils.safetensors import move_and_save
|
||||
from colossalai.utils.safetensors import save
|
||||
|
||||
if id(model) not in self.pinned_state_dicts:
|
||||
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
|
||||
writer = move_and_save(checkpoint, state_dict, self.pinned_state_dicts[id(model)])
|
||||
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(complete_state_dict)
|
||||
for name, param in complete_state_dict.items():
|
||||
self.pinned_state_dicts[id(model)][name].copy_(param)
|
||||
complete_state_dict[name] = self.pinned_state_dicts[id(model)][name]
|
||||
writer = save(path=checkpoint, state_dict=complete_state_dict)
|
||||
self.async_writers.append(writer)
|
||||
else:
|
||||
save_state_dict(complete_state_dict, checkpoint, use_safetensors)
|
||||
|
@ -757,6 +850,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
# gather complete state from tp shards & dp shards
|
||||
param_id = optimizer.param_info["param2id"][id(working_param)]
|
||||
original_shape = optimizer.param_info["param2shape"][id(working_param)]
|
||||
|
||||
local_states[param_id] = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state(
|
||||
state,
|
||||
working_param,
|
||||
|
@ -776,6 +870,18 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
]
|
||||
state_dict = {"param_groups": param_groups, "state": local_states}
|
||||
if self.coordinator.is_master():
|
||||
if use_async:
|
||||
from colossalai.utils.safetensors import save
|
||||
|
||||
flatten_state_dict, metadata = _flatten_optim_state_dict(state_dict)
|
||||
if id(optimizer) not in self.pinned_state_dicts:
|
||||
self.pinned_state_dicts[id(optimizer)] = create_pinned_state_dict(flatten_state_dict)
|
||||
for k, v in flatten_state_dict.items():
|
||||
self.pinned_state_dicts[id(optimizer)][k].copy_(v)
|
||||
flatten_state_dict[k] = self.pinned_state_dicts[id(optimizer)][k]
|
||||
writer = save(path=checkpoint, state_dict=flatten_state_dict, metadata=metadata)
|
||||
self.async_writers.append(writer)
|
||||
else:
|
||||
save_state_dict(state_dict, checkpoint, use_safetensors=False)
|
||||
else:
|
||||
# When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict.
|
||||
|
@ -792,6 +898,18 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
state_dict = {"param_groups": param_groups, "state": dict()}
|
||||
for _states in states_list:
|
||||
state_dict["state"].update(_states)
|
||||
if use_async:
|
||||
from colossalai.utils.safetensors import save
|
||||
|
||||
flatten_state_dict, metadata = _flatten_optim_state_dict(state_dict)
|
||||
if id(optimizer) not in self.pinned_state_dicts:
|
||||
self.pinned_state_dicts[id(optimizer)] = create_pinned_state_dict(flatten_state_dict)
|
||||
for k, v in flatten_state_dict.items():
|
||||
self.pinned_state_dicts[id(optimizer)][k].copy_(v)
|
||||
flatten_state_dict[k] = self.pinned_state_dicts[id(optimizer)][k]
|
||||
writer = save(path=checkpoint, state_dict=flatten_state_dict, metadata=metadata)
|
||||
self.async_writers.append(writer)
|
||||
else:
|
||||
save_state_dict(state_dict, checkpoint, use_safetensors=False)
|
||||
|
||||
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
|
||||
|
@ -818,6 +936,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
|
||||
|
||||
# Complete optimizer state_dict loaded from checkpoint, need to be processed later.
|
||||
if checkpoint.endswith(".safetensors"):
|
||||
state_dict = load_flat(checkpoint)
|
||||
else:
|
||||
state_dict = load_state_dict(checkpoint)
|
||||
|
||||
# Load param_groups.
|
||||
|
@ -872,6 +993,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
use_zero: bool,
|
||||
inplace: bool,
|
||||
device: torch.device = torch.device("cpu"),
|
||||
pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> OrderedDict:
|
||||
"""
|
||||
With given parameter and its optimizer states, gather the complete optimizer state for saving.
|
||||
|
@ -895,6 +1017,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
state_ = state if inplace else copy.deepcopy(state)
|
||||
|
||||
for k, v in state_.items():
|
||||
if v is None:
|
||||
continue
|
||||
if isinstance(v, torch.Tensor) and k != "step":
|
||||
# First gather Zero shards.
|
||||
if use_zero:
|
||||
|
@ -915,6 +1039,12 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
v = init_as_padded_tensor(v, v.shape[padding_dim], original_shape[padding_dim], padding_dim)
|
||||
v = to_unpadded_tensor(v)
|
||||
|
||||
if pinned_state_dicts is not None:
|
||||
if k not in pinned_state_dicts:
|
||||
pinned_state_dicts[k] = torch.empty_like(v, pin_memory=True, device="cpu")
|
||||
pinned_state_dicts[k].copy_(v)
|
||||
state_[k] = pinned_state_dicts[k]
|
||||
else:
|
||||
state_[k] = v.detach().clone().to(device)
|
||||
|
||||
return state_
|
||||
|
|
|
@ -44,12 +44,13 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
|
|||
global_dp_group: ProcessGroup,
|
||||
pp_group: ProcessGroup,
|
||||
tp_group: ProcessGroup,
|
||||
sp_group: ProcessGroup,
|
||||
ep_group: ProcessGroup,
|
||||
moe_dp_group: ProcessGroup,
|
||||
zero_stage: int,
|
||||
verbose: bool = True,
|
||||
) -> None:
|
||||
super().__init__(global_dp_group, pp_group, tp_group, zero_stage, verbose)
|
||||
super().__init__(global_dp_group, pp_group, tp_group, sp_group, zero_stage, verbose)
|
||||
self.global_dp_group = global_dp_group
|
||||
self.global_dp_rank = dist.get_rank(global_dp_group)
|
||||
self.global_dp_size = dist.get_world_size(global_dp_group)
|
||||
|
@ -158,7 +159,7 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
|
|||
state_dict_shard = MoECheckpointIO._model_sharder(model, size_per_shard=size_per_shard)
|
||||
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
|
||||
index_file = CheckpointIndexFile(checkpoint)
|
||||
control_saving = self.tp_rank == 0
|
||||
control_saving = self.tp_rank == 0 and self.sp_rank == 0
|
||||
|
||||
if self.pp_size == 1 and self.ep_size == 1:
|
||||
# When pipeline is not used, save the model shards as in general checkpointIO
|
||||
|
@ -415,7 +416,7 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
|
|||
# e.g. dp_size = 4, moe_dp_size = 2, ep_size = 2 and use gather
|
||||
# rank 0 saves moe & non-moe params; rank 1 only saves moe params
|
||||
# rank 3 & 4 save nothing
|
||||
control_saving = self.tp_rank == 0 and self.moe_dp_rank == 0
|
||||
control_saving = self.tp_rank == 0 and self.moe_dp_rank == 0 and self.sp_rank == 0
|
||||
|
||||
if self.pp_size == 1 and self.ep_size == 1:
|
||||
# When pipeline is not used, save the optimizer shards as in general checkpointIO
|
||||
|
|
|
@ -19,6 +19,7 @@ from colossalai.tensor.d_tensor import (
|
|||
to_global,
|
||||
to_global_for_customized_distributed_tensor,
|
||||
)
|
||||
from colossalai.utils.safetensors import _flatten_optim_state_dict
|
||||
|
||||
SAFE_WEIGHTS_NAME = "model.safetensors"
|
||||
WEIGHTS_NAME = "pytorch_model.bin"
|
||||
|
@ -266,6 +267,63 @@ def save_state_dict_shards(
|
|||
|
||||
|
||||
def async_save_state_dict_shards(
|
||||
sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
|
||||
checkpoint: str,
|
||||
index_file: "CheckpointIndexFile",
|
||||
base_filename: str,
|
||||
is_master: bool,
|
||||
use_pp_format: bool = False,
|
||||
state_preprocess: bool = False,
|
||||
) -> Tuple[int, list]:
|
||||
"""
|
||||
Save a sharded state dict on the master rank only; this method can be used for both model and optimizer states.
|
||||
Args:
|
||||
sharded_state_dict (Iterator[Tuple[OrderedDict, int]]): a generator of shards, each shard contains state dict and shard size.
|
||||
checkpoint (str): The path of checkpoint directory as string.
|
||||
index_file (CheckpointIndexFile): The index file object to be updated.
|
||||
base_filename (str): Decides the prefix of filenames of shards.
|
||||
is_master (bool): Whether current rank is main process.
|
||||
state_preprocess (bool, optional): Whether to flatten nested optimizer states into flat keys before writing. Defaults to False.
|
||||
use_pp_format (bool, optional): Whether to save the files in pipeline format including stage information. Defaults to False.
|
||||
|
||||
Returns:
|
||||
Tuple[int, list]: the total size of shards and the list of async file writers.
|
||||
"""
|
||||
from colossalai.utils.safetensors import save
|
||||
|
||||
total_size = 0
|
||||
shard_filenames = []
|
||||
writers = []
|
||||
for idx, shard_pair in enumerate(sharded_state_dict):
|
||||
shard, current_size = shard_pair
|
||||
# Just loop over the sharder and gather to other ranks if not master
|
||||
if not is_master:
|
||||
del shard
|
||||
continue
|
||||
shard_file = get_shard_filename(base_filename, idx)
|
||||
total_size = total_size + current_size
|
||||
for key in shard.keys():
|
||||
index_file.append_weight_map(key, shard_file)
|
||||
checkpoint_file_path = os.path.join(checkpoint, shard_file)
|
||||
|
||||
if state_preprocess:
|
||||
state_dict, _ = _flatten_optim_state_dict(state_dict=shard, seperator=".")
|
||||
else:
|
||||
state_dict = shard
|
||||
|
||||
# Only save on master rank.
|
||||
writer = save(checkpoint_file_path, state_dict=state_dict)
|
||||
writers.append(writer)
|
||||
shard_filenames.append(shard_file)
|
||||
del shard
|
||||
|
||||
# Clean folder, delete unneeded files.
|
||||
clean_folder(checkpoint, base_filename, shard_filenames, is_master=is_master, use_pp_format=use_pp_format)
|
||||
|
||||
return total_size, writers
|
||||
|
||||
|
||||
def async_move_save_state_dict_shards(
|
||||
sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
|
||||
checkpoint: str,
|
||||
index_file: "CheckpointIndexFile",
|
||||
|
@ -273,6 +331,7 @@ def async_save_state_dict_shards(
|
|||
is_master: bool,
|
||||
pinned_state_dict: Optional[Dict[str, torch.Tensor]],
|
||||
use_pp_format: bool = False,
|
||||
state_preprocess: bool = False,
|
||||
) -> Tuple[int, Dict[str, torch.Tensor], list]:
|
||||
"""
|
||||
Save a sharded state dict on the master rank only; this method can be used for both model and optimizer states.
|
||||
|
@ -309,14 +368,19 @@ def async_save_state_dict_shards(
|
|||
index_file.append_weight_map(key, shard_file)
|
||||
checkpoint_file_path = os.path.join(checkpoint, shard_file)
|
||||
|
||||
if pinned_state_dict is not None:
|
||||
sub_pinned_state_dict = {k: pinned_state_dict[k] for k in shard.keys()}
|
||||
if state_preprocess:
|
||||
state_dict, _ = _flatten_optim_state_dict(state_dict=shard)
|
||||
else:
|
||||
sub_pinned_state_dict = create_pinned_state_dict(shard)
|
||||
state_dict = shard
|
||||
|
||||
if pinned_state_dict is not None:
|
||||
sub_pinned_state_dict = {k: pinned_state_dict[k] for k in state_dict.keys()}
|
||||
else:
|
||||
sub_pinned_state_dict = create_pinned_state_dict(state_dict)
|
||||
returned_state_dict.update(sub_pinned_state_dict)
|
||||
|
||||
# Only save on master rank.
|
||||
writer = move_and_save(checkpoint_file_path, shard, sub_pinned_state_dict)
|
||||
writer = move_and_save(checkpoint_file_path, state_dict, sub_pinned_state_dict)
|
||||
writers.append(writer)
|
||||
shard_filenames.append(shard_file)
|
||||
del shard
|
||||
|
@ -327,7 +391,11 @@ def async_save_state_dict_shards(
|
|||
return total_size, returned_state_dict, writers
|
||||
|
||||
|
||||
def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
|
||||
def shard_model_checkpoint(
|
||||
state_dict: torch.Tensor,
|
||||
max_shard_size: int = 1024,
|
||||
pinned_state_dicts: Optional[Dict[int, Dict[str, torch.Tensor]]] = None,
|
||||
) -> Iterator[Tuple[OrderedDict, int]]:
|
||||
"""
|
||||
Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
|
||||
given size.
|
||||
|
@ -336,6 +404,11 @@ def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024)
|
|||
|
||||
for key, weight in state_dict.items():
|
||||
if not is_distributed_tensor(weight):
|
||||
if pinned_state_dicts is not None:
|
||||
if key not in pinned_state_dicts:
|
||||
pinned_state_dicts[key] = torch.empty_like(weight, pin_memory=True, device="cpu")
|
||||
pinned_state_dicts[key].copy_(weight)
|
||||
weight = pinned_state_dicts[key]
|
||||
block, block_size = state_dict_sharder.append_param(key, weight)
|
||||
|
||||
if block != None:
|
||||
|
@ -345,7 +418,9 @@ def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024)
|
|||
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
|
||||
|
||||
|
||||
def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
|
||||
def shard_optimizer_checkpoint(
|
||||
state_dict: dict, max_shard_size: int = 1024, pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None
|
||||
) -> Iterator[Tuple[OrderedDict, int]]:
|
||||
"""
|
||||
Splits an optimizer state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
|
||||
given size.
|
||||
|
@ -356,6 +431,15 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
|
|||
state_dict_sharder = StateDictSharder(max_shard_size)
|
||||
|
||||
for param_id, state in states.items():
|
||||
if pinned_state_dicts is not None:
|
||||
if param_id not in pinned_state_dicts:
|
||||
pinned_state_dicts[param_id] = {}
|
||||
for k, v in state.items():
|
||||
if k not in pinned_state_dicts[param_id]:
|
||||
pinned_state_dicts[param_id][k] = torch.empty_like(v, pin_memory=True, device="cpu")
|
||||
pinned_state_dicts[param_id][k].copy_(v)
|
||||
state[k] = pinned_state_dicts[param_id][k]
|
||||
|
||||
block, block_size = state_dict_sharder.append_optim_state(param_id, state)
|
||||
if block != None:
|
||||
yield block, block_size
|
||||
|
|
|
@ -71,6 +71,8 @@ def _flatten_optim_state_dict(state_dict: dict, seperator: str = ".") -> Tuple[d
|
|||
|
||||
for idx, d in states.items():
|
||||
for k, v in d.items():
|
||||
if v is None:
|
||||
continue
|
||||
nested_key = f"state{seperator}{idx}{seperator}{k}"
|
||||
if not isinstance(v, torch.Tensor):
|
||||
non_tensor_keys.append(nested_key)
|
||||
|
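To make the flattened layout concrete: with the default `.` separator, a hypothetical optimizer state for parameter index 0 maps to flat safetensors keys as sketched below; values that are not tensors are recorded in the `non_tensor_keys` metadata list instead of being written as tensors (param_groups handling is outside this hunk):

```python
# Hypothetical illustration of the key scheme built by the loop above.
nested = {"state": {0: {"exp_avg": "tensor", "exp_avg_sq": "tensor", "step": 10}}}
flat_tensor_keys = ["state.0.exp_avg", "state.0.exp_avg_sq"]  # written to safetensors
non_tensor_keys = ["state.0.step"]  # kept in header metadata (here step is a plain int)
```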
@ -87,7 +89,8 @@ def _flatten_optim_state_dict(state_dict: dict, seperator: str = ".") -> Tuple[d
|
|||
|
||||
def _unflatten_optim_state_dict(flat_dict: dict, metadata: Optional[dict] = None, seperator: str = "."):
|
||||
state_dict = {}
|
||||
if metadata is not None:
|
||||
|
||||
if metadata is not None and "non_tensor_keys" in metadata:
|
||||
non_tensor_keys = json.loads(metadata["non_tensor_keys"])
|
||||
else:
|
||||
non_tensor_keys = []
|
||||
|
@ -128,8 +131,10 @@ def prepare(
|
|||
header = {}
|
||||
offset = 0
|
||||
|
||||
header_metadata = {"format": "pt"}
|
||||
if metadata is not None:
|
||||
header["__metadata__"] = metadata
|
||||
header_metadata.update(metadata)
|
||||
header["__metadata__"] = header_metadata
|
||||
|
||||
for name, tensor in data.items():
|
||||
n = tensor.numel() * tensor.element_size()
|
||||
|
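For reference on where that merged metadata ends up: a safetensors file begins with an 8-byte little-endian header length followed by a JSON header, and the custom entries are placed inside `__metadata__` next to the default `"format": "pt"` field added above. A hypothetical header for a single flattened state tensor (names and sizes are illustrative):

```python
# Hypothetical JSON header produced by prepare(); metadata values are strings.
header = {
    "__metadata__": {"format": "pt", "non_tensor_keys": '["state.0.step"]'},
    "state.0.exp_avg": {"dtype": "F32", "shape": [1024], "data_offsets": [0, 4096]},
}
```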
@ -172,8 +177,9 @@ def move_and_save(
|
|||
path: str,
|
||||
state_dict: Dict[str, torch.Tensor],
|
||||
state_dict_pinned: Optional[Dict[str, torch.Tensor]] = None,
|
||||
metadata: Optional[Dict[str, str]] = None,
|
||||
) -> None:
|
||||
prepared_data, _, tensor_keys = prepare(state_dict)
|
||||
prepared_data, _, tensor_keys = prepare(state_dict, metadata)
|
||||
n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset
|
||||
f_writer = AsyncFileWriter(path, n_entries=ASYNC_WRITE_ENTRIES, backend="pthread", n_tasks=2 + len(tensor_keys))
|
||||
f_writer.write(n.to_bytes(8, byteorder="little"))
|
||||
|
@ -188,9 +194,9 @@ def move_and_save(
|
|||
return f_writer
|
||||
|
||||
|
||||
def load_flat(checkpoint_path):
|
||||
def load_flat(checkpoint_path, seperator: str = "."):
|
||||
with safe_open(checkpoint_path, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
state_dict_load = load_file(checkpoint_path)
|
||||
state_dict = _unflatten_optim_state_dict(state_dict_load, metadata)
|
||||
state_dict = _unflatten_optim_state_dict(state_dict_load, metadata, seperator)
|
||||
return state_dict
|
||||
|
|
|
@ -903,6 +903,7 @@ class GeminiDDP(ModelWrapper):
|
|||
keep_vars: bool = False,
|
||||
max_shard_size: int = 1024,
|
||||
only_rank_0: bool = True,
|
||||
pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> Iterator[Tuple[OrderedDict, int]]:
|
||||
"""Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.
|
||||
|
||||
|
@ -943,6 +944,13 @@ class GeminiDDP(ModelWrapper):
|
|||
gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0))
|
||||
gathered_param = gathered_param_buffer.pop(param_to_save)
|
||||
|
||||
if pinned_state_dicts is not None:
|
||||
if (prefix + name) not in pinned_state_dicts:
|
||||
pinned_state_dicts[prefix + name] = torch.empty_like(
|
||||
gathered_param, pin_memory=True, device="cpu"
|
||||
)
|
||||
pinned_state_dicts[prefix + name].copy_(gathered_param)
|
||||
gathered_param = pinned_state_dicts[prefix + name]
|
||||
block, block_size = sharder.append_param(prefix + name, gathered_param)
|
||||
if block is not None:
|
||||
yield block, block_size
|
||||
|
@ -954,6 +962,11 @@ class GeminiDDP(ModelWrapper):
|
|||
for name, buf in self.named_buffers():
|
||||
if buf is not None and name not in self._non_persistent_buffers_set:
|
||||
buffer = buf if keep_vars else buf.detach()
|
||||
if pinned_state_dicts is not None:
|
||||
if (prefix + name) not in pinned_state_dicts:
|
||||
pinned_state_dicts[prefix + name] = torch.empty_like(buffer, pin_memory=True, device="cpu")
|
||||
pinned_state_dicts[prefix + name].copy_(buffer)
|
||||
buffer = pinned_state_dicts[prefix + name]
|
||||
block, block_size = sharder.append_param(prefix + name, buffer)
|
||||
if block is not None:
|
||||
yield block, block_size
|
||||
|
@ -964,6 +977,11 @@ class GeminiDDP(ModelWrapper):
|
|||
is not torch.nn.Module.get_extra_state
|
||||
):
|
||||
extra_state = self.get_extra_state()
|
||||
if pinned_state_dicts is not None:
|
||||
if extra_state_key not in pinned_state_dicts:
|
||||
pinned_state_dicts[extra_state_key] = torch.empty_like(extra_state, pin_memory=True, device="cpu")
|
||||
pinned_state_dicts[extra_state_key].copy_(extra_state)
|
||||
extra_state = pinned_state_dicts[extra_state_key]
|
||||
block, block_size = sharder.append_param(extra_state_key, extra_state)
|
||||
if block is not None:
|
||||
yield block, block_size
|
||||
|
|
|
@ -809,7 +809,11 @@ class GeminiOptimizer(OptimizerWrapper):
self.optimizer_loading_epilogue()

def state_shard(
self, prefix: str = "", max_shard_size: int = 1024, only_rank_0: bool = True
self,
prefix: str = "",
max_shard_size: int = 1024,
only_rank_0: bool = True,
pinned_state_dicts: Optional[Dict[int, Dict[str, torch.Tensor]]] = None,
) -> Iterator[Tuple[OrderedDict, int]]:
"""Returns dictionaries containing shards of optimizer states one by one.
The max size of each dictionary shard is specified by ``max_shard_size``.

@ -829,6 +833,16 @@ class GeminiOptimizer(OptimizerWrapper):
dist.barrier()
state = self.collect_states(param_id=param_id, only_rank_0=only_rank_0)
if pinned_state_dicts is not None:
if param_id not in pinned_state_dicts:
pinned_state_dicts[param_id] = {}
for k, v in state.items():
if v is None:
continue
if k not in pinned_state_dicts[param_id]:
pinned_state_dicts[param_id][k] = torch.empty_like(v, pin_memory=True, device="cpu")
pinned_state_dicts[param_id][k].copy_(v)
state[k] = pinned_state_dicts[param_id][k]
block, block_size = sharder.append_optim_state(param_id, state)
if block is not None:
yield block, block_size
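state_shard gets the same treatment, but the pinned cache is nested: one inner dict of page-locked CPU tensors per parameter id, covering each optimizer state entry (e.g. exp_avg, exp_avg_sq). A standalone illustration of that nesting with plain tensors (names and values are made up for the sketch):

from typing import Dict, Optional

import torch

pinned: Dict[int, Dict[str, torch.Tensor]] = {}

def stage_optim_state(param_id: int, state: Dict[str, Optional[torch.Tensor]]) -> None:
    slot = pinned.setdefault(param_id, {})
    for k, v in state.items():
        if v is None:
            continue  # skip empty entries, as the diff does
        if k not in slot:
            slot[k] = torch.empty_like(v, pin_memory=True, device="cpu")
        slot[k].copy_(v)
        state[k] = slot[k]  # the yielded shard now points at page-locked memory

state = {"exp_avg": torch.randn(8), "exp_avg_sq": torch.randn(8), "step": torch.tensor(10.0)}
stage_optim_state(0, state)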
@ -35,7 +35,10 @@ OPTIM_PLACEMENT_CONFIGS = [
@parameterize("use_safetensors", [False, True])
@parameterize("tp_size", [1, 2])
@parameterize("zero_size", [2])
def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool, tp_size: int, zero_size: int):
@parameterize("use_async", [False, True])
def exam_state_dict_with_origin(
placement_config, model_name, use_safetensors: bool, tp_size: int, zero_size: int, use_async: bool
):
from transformers import BertForSequenceClassification

(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))

@ -70,7 +73,10 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
"",
(model_size / 3),
use_safetensors=use_safetensors,
use_async=use_async,
)
booster.checkpoint_io._sync_d2h()
booster.checkpoint_io._sync_io()
dist.barrier()
new_bert_model = BertForSequenceClassification.from_pretrained(pretrained_path)
check_state_dict_equal(bert_model.state_dict(only_rank_0=False), new_bert_model.state_dict())

@ -83,7 +89,10 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
@parameterize("size_per_shard", [32])
@parameterize("tp_size", [1, 2])
@parameterize("zero_size", [2])
def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int, tp_size: int, zero_size: int):
@parameterize("use_async", [False, True])
def exam_state_dict(
placement_config, shard: bool, model_name: str, size_per_shard: int, tp_size: int, zero_size: int, use_async: bool
):
(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
criterion = lambda x: x.mean()
enable_flash_attention = True if tp_size > 1 else False

@ -124,14 +133,18 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
with shared_tempdir() as tempdir:
model_ckpt_path = f"{tempdir}/model"
optimizer_ckpt_path = f"{tempdir}/optimizer"
booster.save_model(
model,
model_ckpt_path,
shard=shard,
size_per_shard=size_per_shard,
)
booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
if not shard and use_async:
model_ckpt_path = f"{model_ckpt_path}.safetensors"
optimizer_ckpt_path = f"{optimizer_ckpt_path}.safetensors"
booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async)
booster.save_optimizer(
optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async
)
booster.checkpoint_io._sync_d2h()
booster.checkpoint_io._sync_io()
dist.barrier()
booster.load_model(new_model, model_ckpt_path)

@ -155,8 +168,18 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
loss = criterion(output[output_key])
booster.backward(loss, new_optimizer)
new_optimizer.step()
booster.save_model(new_model, model_ckpt_path, shard=shard)
booster.save_optimizer(new_optimizer, optimizer_ckpt_path, shard=shard)
with shared_tempdir() as new_tempdir:
model_ckpt_path = f"{new_tempdir}/model"
optimizer_ckpt_path = f"{new_tempdir}/optimizer"
if not shard and use_async:
model_ckpt_path = f"{model_ckpt_path}.safetensors"
optimizer_ckpt_path = f"{optimizer_ckpt_path}.safetensors"
booster.save_model(new_model, model_ckpt_path, shard=shard, use_async=use_async)
booster.save_optimizer(new_optimizer, optimizer_ckpt_path, shard=shard, use_async=use_async)
booster.checkpoint_io._sync_d2h()
booster.checkpoint_io._sync_io()

def exam_lazy_from_pretrained():
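The updated Gemini tests follow one fixed sequence: save with use_async=True, drain the device-to-host copies and the background writers, then barrier before reading the files back. A hedged sketch of that sequence, assuming an already-launched ColossalAI process group and an already-boosted model and optimizer; only calls that appear in this diff are used, and the paths are illustrative:

import torch.distributed as dist

def async_save_and_reload(booster, model, optimizer, ckpt_dir: str, shard: bool = False):
    model_path = f"{ckpt_dir}/model"
    optim_path = f"{ckpt_dir}/optimizer"
    if not shard:
        # unsharded async checkpoints are written as single safetensors files
        model_path = f"{model_path}.safetensors"
        optim_path = f"{optim_path}.safetensors"
    booster.save_model(model, model_path, shard=shard, use_async=True)
    booster.save_optimizer(optimizer, optim_path, shard=shard, use_async=True)
    booster.checkpoint_io._sync_d2h()  # wait for copies into the pinned buffers
    booster.checkpoint_io._sync_io()   # wait for the background safetensors writers
    dist.barrier()                     # make sure every rank finished writing
    booster.load_model(model, model_path)
    booster.load_optimizer(optimizer, optim_path)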
@ -19,7 +19,8 @@ from colossalai.testing import check_state_dict_equal, clear_cache_before_run, p
@clear_cache_before_run()
@parameterize("use_safetensors", [True, False])
def test_unsharded_checkpoint(use_safetensors: bool):
@parameterize("use_async", [False, True])
def test_unsharded_checkpoint(use_safetensors: bool, use_async: bool):
# create a model and optimizer
model = resnet18()
optimizer = Adam(model.parameters(), lr=0.001)

@ -36,18 +37,21 @@ def test_unsharded_checkpoint(use_safetensors: bool):
lr_scheduler.step()
# create a temp file for checkpoint
if use_safetensors:
if use_async or use_safetensors:
suffix = ".safetensors"
else:
suffix = ".bin"
model_ckpt_tempfile = tempfile.NamedTemporaryFile(suffix=suffix)
if use_async:
optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile(suffix=suffix)
else:
optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile()
lr_scheduler_ckpt_tempfile = tempfile.NamedTemporaryFile()
# save the model, optimizer, lr_scheduler
ckpt_io = GeneralCheckpointIO()
ckpt_io.save_model(model, model_ckpt_tempfile.name, use_safetensors=use_safetensors)
ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name)
ckpt_io.save_model(model, model_ckpt_tempfile.name, use_safetensors=use_safetensors, use_async=use_async)
ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name, use_async=use_async)
ckpt_io.save_lr_scheduler(lr_scheduler, lr_scheduler_ckpt_tempfile.name)
# create new model

@ -55,6 +59,9 @@ def test_unsharded_checkpoint(use_safetensors: bool):
new_optimizer = Adam(new_model.parameters(), lr=0.001)
new_lr_scheduler = CosineAnnealingWarmupLR(optimizer, total_steps=10)
ckpt_io._sync_d2h()
ckpt_io._sync_io()
# load the model, optimizer, lr_scheduler
ckpt_io.load_model(new_model, model_ckpt_tempfile.name)
ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name)
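The only subtle change in test_unsharded_checkpoint is the file suffix: an asynchronous save always goes through safetensors, so the temp files need a .safetensors suffix even when use_safetensors is False. A standalone restatement of that selection logic (the helper name is illustrative, not from the test):

import tempfile

def make_ckpt_tempfiles(use_safetensors: bool, use_async: bool):
    # async saves are safetensors-only; otherwise fall back to the .bin default
    model_suffix = ".safetensors" if (use_async or use_safetensors) else ".bin"
    model_file = tempfile.NamedTemporaryFile(suffix=model_suffix)
    # the optimizer file only needs the suffix when it is written asynchronously
    optim_file = tempfile.NamedTemporaryFile(suffix=".safetensors" if use_async else "")
    return model_file, optim_file

model_file, optim_file = make_ckpt_tempfiles(use_safetensors=False, use_async=True)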
@ -66,7 +73,8 @@ def test_unsharded_checkpoint(use_safetensors: bool):
@pytest.mark.parametrize("use_safetensors", [True, False])
def test_sharded_model_checkpoint(use_safetensors: bool):
@pytest.mark.parametrize("use_async", [False, True])
def test_sharded_model_checkpoint(use_safetensors: bool, use_async: bool):
# create a model and optimizer
model = resnet18()
optimizer = Adam(model.parameters(), lr=0.001)

@ -79,21 +87,20 @@ def test_sharded_model_checkpoint(use_safetensors: bool):
loss.backward()
optimizer.step()
# create a temp file for checkpoint
if use_safetensors:
pass
else:
pass
model_ckpt_dir = tempfile.TemporaryDirectory()
optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile()
# save the model and optimizer
ckpt_io = GeneralCheckpointIO()
ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=use_safetensors)
ckpt_io.save_model(
model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=use_safetensors, use_async=use_async
)
ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name, shard=False)
ckpt_io._sync_d2h()
ckpt_io._sync_io()
# create new model
new_model = resnet18()
new_optimizer = Adam(new_model.parameters(), lr=0.001)

@ -106,7 +113,8 @@ def test_sharded_model_checkpoint(use_safetensors: bool):
check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict())

def test_sharded_optimizer_checkpoint():
@pytest.mark.parametrize("use_async", [False, True])
def test_sharded_optimizer_checkpoint(use_async: bool):
# create a model and optimizer
model = resnet18()
optimizer = Adam(model.parameters(), lr=0.001)

@ -128,7 +136,10 @@ def test_sharded_optimizer_checkpoint():
ckpt_io = GeneralCheckpointIO()
ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=False)
ckpt_io.save_optimizer(optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10)
ckpt_io.save_optimizer(optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10, use_async=use_async)
ckpt_io._sync_d2h()
ckpt_io._sync_io()
# create new model
new_model = resnet18()

@ -148,9 +159,16 @@ def test_sharded_optimizer_checkpoint():
loss.backward()
new_optimizer.step()
# create temp directories for checkpoint
model_ckpt_dir = tempfile.TemporaryDirectory()
optimizer_ckpt_dir = tempfile.TemporaryDirectory()
# save the newly got optimizer
ckpt_io.save_model(new_model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=False)
ckpt_io.save_optimizer(new_optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10)
ckpt_io.save_optimizer(new_optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10, use_async=use_async)
ckpt_io._sync_d2h()
ckpt_io._sync_io()
# create another new model
new_new_model = resnet18()

@ -164,7 +182,8 @@ def test_sharded_optimizer_checkpoint():
check_state_dict_equal(new_optimizer.state_dict(), new_new_optimizer.state_dict())

def test_sharded_optimizer_multiple_param_groups():
@pytest.mark.parametrize("use_async", [False, True])
def test_sharded_optimizer_multiple_param_groups(use_async: bool):
# create a model and optimizer
model = resnet18()
optimizer = Adam(

@ -188,7 +207,10 @@ def test_sharded_optimizer_multiple_param_groups():
ckpt_io = GeneralCheckpointIO()
ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=False)
ckpt_io.save_optimizer(optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10)
ckpt_io.save_optimizer(optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10, use_async=use_async)
ckpt_io._sync_d2h()
ckpt_io._sync_io()
# create new model
new_model = resnet18()
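Because GeneralCheckpointIO needs no process group, the sharded-optimizer flow exercised above can be reproduced in a single process. A hedged sketch under the assumption that this commit's use_async path is available and that torchvision is installed (paths and sizes are illustrative):

import tempfile

import torch
from torch.optim import Adam
from torchvision.models import resnet18

from colossalai.checkpoint_io import GeneralCheckpointIO

model = resnet18()
optimizer = Adam(model.parameters(), lr=1e-3)
model(torch.randn(4, 3, 32, 32)).mean().backward()
optimizer.step()  # populate optimizer state so there is something to shard

ckpt_io = GeneralCheckpointIO()
optim_dir = tempfile.TemporaryDirectory()
ckpt_io.save_optimizer(optimizer, optim_dir.name, shard=True, size_per_shard=10, use_async=True)
ckpt_io._sync_d2h()  # flush the pinned-buffer copies
ckpt_io._sync_io()   # wait for the background shard writers to finish

new_optimizer = Adam(resnet18().parameters(), lr=1e-3)
ckpt_io.load_optimizer(new_optimizer, optim_dir.name)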
@ -38,12 +38,13 @@ else:
]

@parameterize("shard", [True, False])
@parameterize("shard", [False, True])
@parameterize("model_name", ["transformers_llama_for_causal_lm"])
@parameterize("size_per_shard", [32])
@parameterize("test_config", TEST_CONFIGS)
@parameterize("use_async", [False, True])
@clear_cache_before_run()
def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict):
def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict, use_async: bool):
(model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next(
iter(model_zoo.get_sub_registry(model_name).values())
)

@ -85,8 +86,16 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf
with shared_tempdir() as tempdir:
model_ckpt_path = f"{tempdir}/model"
optimizer_ckpt_path = f"{tempdir}/optimizer"
booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
if not shard and use_async:
model_ckpt_path = f"{model_ckpt_path}.safetensors"
optimizer_ckpt_path = f"{optimizer_ckpt_path}.safetensors"
booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async)
booster.save_optimizer(
optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async
)
booster.checkpoint_io._sync_d2h()
booster.checkpoint_io._sync_io()
dist.barrier()
new_model = model_fn().cuda()
@ -12,14 +12,15 @@ from colossalai.interface import OptimizerWrapper
from colossalai.testing import check_state_dict_equal, parameterize, rerun_if_address_is_in_use, spawn

@parameterize("shard", [True, False])
@parameterize("shard", [False, True])
@parameterize("size_per_shard", [16, 128])
def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int):
@parameterize("use_async", [False, True])
def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int, use_async: bool):
plugin = TorchDDPPlugin()
booster = Booster(plugin=plugin)
model = resnet18()
criterion = lambda x: x.mean()
optimizer = SGD((model.parameters()), lr=0.001)
optimizer = SGD((model.parameters()), lr=0.001, momentum=0.5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion, lr_scheduler=scheduler)

@ -39,9 +40,18 @@ def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int):
model_ckpt_path = f"{tempdir}/model"
optimizer_ckpt_path = f"{tempdir}/optimizer"
lr_scheduler_ckpt_path = f"{tempdir}/lr_scheduler"
booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
if not shard and use_async:
model_ckpt_path = f"{model_ckpt_path}.safetensors"
optimizer_ckpt_path = f"{optimizer_ckpt_path}.safetensors"
booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async)
booster.save_optimizer(
optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async
)
booster.save_lr_scheduler(scheduler, lr_scheduler_ckpt_path)
booster.checkpoint_io._sync_d2h()
booster.checkpoint_io._sync_io()
dist.barrier()
new_model = resnet18()
@ -12,7 +12,7 @@ if version.parse(torch.__version__) >= version.parse("1.12.0"):
from colossalai.booster.plugin import TorchFSDPPlugin
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn

def compare_nested_dict(dict1, dict2):

@ -43,7 +43,8 @@ def compare_nested_dict(dict1, dict2):
return True

def check_torch_fsdp_ckpt():
@parameterize("use_async", [False, True])
def check_torch_fsdp_ckpt(use_async: bool):
model = resnet18()
plugin = TorchFSDPPlugin()
booster = Booster(plugin=plugin)

@ -65,10 +66,17 @@ def check_torch_fsdp_ckpt():
model_ckpt_path = f"{tempdir}/model"
optim_ckpt_path = f"{tempdir}/optimizer"
if use_async:
model_ckpt_path = f"{model_ckpt_path}.safetensors"
optim_ckpt_path = f"{optim_ckpt_path}.safetensors"
run_model()
booster.save_model(fsdp_model, model_ckpt_path, shard=False)
booster.save_optimizer(optimizer, optim_ckpt_path, shard=False)
booster.save_model(fsdp_model, model_ckpt_path, shard=False, use_async=use_async)
booster.save_optimizer(optimizer, optim_ckpt_path, shard=False, use_async=use_async)
booster.checkpoint_io._sync_d2h()
booster.checkpoint_io._sync_io()
full_msd = fsdp_model.state_dict()
# full_osd = FSDP.full_optim_state_dict(fsdp_model, optimizer)

@ -106,8 +114,11 @@ def check_torch_fsdp_ckpt():
run_model()
booster.save_model(fsdp_model, model_ckpt_path, shard=True)
booster.save_optimizer(optimizer, optim_ckpt_path, shard=True)
booster.save_model(fsdp_model, model_ckpt_path, shard=True, use_async=use_async)
booster.save_optimizer(optimizer, optim_ckpt_path, shard=True, use_async=use_async)
booster.checkpoint_io._sync_d2h()
booster.checkpoint_io._sync_io()
full_msd = fsdp_model.unwrap().state_dict()
full_osd = FSDP.full_optim_state_dict(optimizer.unwrap_model().unwrap(), optim=optimizer)