[checkpointio] support asyncio for 3d (#6152)

* fix

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update utils.py

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
flybird11111 2024-12-23 10:24:22 +08:00 committed by GitHub
parent aaafb38851
commit 130229fdcb
17 changed files with 776 additions and 188 deletions
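
The diff below extends asynchronous, safetensors-based checkpointing from the general CheckpointIO to the 3D-parallel stack (Gemini, HybridParallel, MoE, TorchDDP and TorchFSDP plugins). The caller-side pattern is sketched here based on the updated tests; the plugin choice, the resnet18/SGD setup and the launch call are illustrative only, while the .safetensors suffix for unsharded async checkpoints and the _sync_d2h()/_sync_io() barriers mirror the test code.

# Usage sketch based on the updated tests; the model/optimizer/launch setup is illustrative.
import os
import torch
import torch.distributed as dist
from torchvision.models import resnet18
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

colossalai.launch_from_torch()  # assumes a torchrun-style launch
booster = Booster(plugin=TorchDDPPlugin())
model = resnet18()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.5)
model, optimizer, _, _, _ = booster.boost(model, optimizer)

ckpt_dir = "./async_ckpt"  # the tests use a directory shared across ranks
os.makedirs(ckpt_dir, exist_ok=True)
# Unsharded async checkpoints are written as single .safetensors files; sharded ones go to a directory.
model_ckpt_path = f"{ckpt_dir}/model.safetensors"
optim_ckpt_path = f"{ckpt_dir}/optimizer.safetensors"

booster.save_model(model, model_ckpt_path, shard=False, use_async=True)
booster.save_optimizer(optimizer, optim_ckpt_path, shard=False, use_async=True)
# Wait for the device-to-host copies and the background writers before reading the files back.
booster.checkpoint_io._sync_d2h()
booster.checkpoint_io._sync_io()
dist.barrier()

booster.load_model(model, model_ckpt_path)
booster.load_optimizer(optimizer, optim_ckpt_path)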

View File

@ -17,6 +17,8 @@ from torch.utils.data.distributed import DistributedSampler
from colossalai.accelerator import get_accelerator
from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
from colossalai.checkpoint_io.utils import (
async_save_state_dict_shards,
create_pinned_state_dict,
get_model_base_filenames,
get_optimizer_base_filenames,
load_shard_state_dict,
@ -28,6 +30,7 @@ from colossalai.cluster import DistCoordinator, ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.logging import get_dist_logger
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.utils.safetensors import load_flat
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.memory_tracer import MemStats
@ -82,7 +85,15 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
state_dict = model.state_dict(only_rank_0=True)
if self.coordinator.is_master():
if use_async:
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async)
from colossalai.utils.safetensors import save
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
for k, v in state_dict.items():
self.pinned_state_dicts[id(model)][k].copy_(v)
state_dict[k] = self.pinned_state_dicts[id(model)][k]
writer = save(checkpoint, state_dict)
self.async_writers.append(writer)
else:
save_state_dict(state_dict, checkpoint, use_safetensors)
@ -106,6 +117,18 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!"
state_dict = optimizer.state_dict()
if self.coordinator.is_master():
if use_async:
from colossalai.utils.safetensors import _flatten_optim_state_dict, save
flatten_state_dict, metadata = _flatten_optim_state_dict(state_dict)
if id(optimizer) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(optimizer)] = create_pinned_state_dict(flatten_state_dict)
for k, v in flatten_state_dict.items():
self.pinned_state_dicts[id(optimizer)][k].copy_(v)
flatten_state_dict[k] = self.pinned_state_dicts[id(optimizer)][k]
writer = save(checkpoint, flatten_state_dict, metadata)
self.async_writers.append(writer)
else:
save_state_dict(state_dict, checkpoint, use_safetensors=False)
def load_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str):
@ -137,17 +160,29 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True)
if use_async and self.coordinator.is_master():
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = {}
pinned_state_dicts = self.pinned_state_dicts[id(model)]
else:
pinned_state_dicts = None
state_dict_shard = model.state_dict_shard(
max_shard_size=max_shard_size, only_rank_0=True, pinned_state_dicts=pinned_state_dicts
)
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
index_file = CheckpointIndexFile(checkpoint_path)
# Save shards of optimizer states.
is_master = self.coordinator.is_master()
if use_async:
super().save_sharded_model(
model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors, use_async
total_size, writers = async_save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint_path,
index_file=index_file,
base_filename=weights_name,
is_master=is_master,
)
self.async_writers.extend(writers)
else:
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
@ -201,7 +236,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
Path(checkpoint).mkdir(parents=True, exist_ok=True)
# Preparing file paths and index file.
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix, use_safetensors=use_async)
index_file = CheckpointIndexFile(checkpoint)
index_file.append_meta_data("param_groups", param_group_file)
@ -212,9 +247,28 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
torch.save(param_groups, group_file_path)
# States are broken into shards within max_shard_size.
state_dict_shard = optimizer.state_shard(prefix=prefix, max_shard_size=size_per_shard, only_rank_0=True)
if use_async and self.coordinator.is_master():
if id(optimizer) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(optimizer)] = {}
pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
else:
pinned_state_dicts = None
state_dict_shard = optimizer.state_shard(
prefix=prefix, max_shard_size=size_per_shard, only_rank_0=True, pinned_state_dicts=pinned_state_dicts
)
# Save shards of optimizer states.
if use_async:
total_size, writers = async_save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=self.coordinator.is_master(),
state_preprocess=True,
)
self.async_writers.extend(writers)
else:
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
@ -264,6 +318,9 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
# Load optimizer states from shard files under checkpoint path.
# For each file, only load the states managed by current process.
for shard_file in checkpoint_files:
if shard_file.endswith(".safetensors"):
state_dict_shard = load_flat(shard_file)
else:
state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False)
optimizer.load_param_states(state_dict_shard)
del state_dict_shard
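
The unsharded async branches above, and the corresponding ones in the other plugins below, all follow the same staging pattern: the checkpoint IO keeps a pinned CPU copy of the state dict keyed by id(model) or id(optimizer), copies the freshly gathered tensors into it, and hands only the pinned tensors to colossalai.utils.safetensors.save, collecting the returned writer in self.async_writers so that _sync_d2h()/_sync_io() can flush it later. A condensed sketch of that pattern follows; the helper name and its io / obj_key parameters are illustrative, not library API.

# Condensed sketch of the pinned-buffer staging used by the async unsharded save paths above.
# `_async_save_unsharded`, `io` and `obj_key` are illustrative names, not part of the codebase.
from colossalai.checkpoint_io.utils import create_pinned_state_dict
from colossalai.utils.safetensors import save

def _async_save_unsharded(io, obj_key, state_dict, checkpoint_path):
    # Reuse one pinned buffer per model/optimizer across repeated saves.
    if obj_key not in io.pinned_state_dicts:
        io.pinned_state_dicts[obj_key] = create_pinned_state_dict(state_dict)
    pinned = io.pinned_state_dicts[obj_key]
    for k, v in state_dict.items():
        pinned[k].copy_(v)         # device-to-host copy lands in pinned memory
        state_dict[k] = pinned[k]  # hand only pinned tensors to the async writer
    writer = save(checkpoint_path, state_dict)
    io.async_writers.append(writer)  # flushed later by _sync_d2h() / _sync_io()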

View File

@ -1488,7 +1488,7 @@ class HybridParallelPlugin(PipelinePluginBase):
)
def get_checkpoint_io(self) -> CheckpointIO:
return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.sp_group, self.zero_stage)
def no_sync(self, model: Module, optimizer: OptimizerWrapper) -> Iterator[None]:
assert (

View File

@ -404,7 +404,13 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
def get_checkpoint_io(self) -> MoECheckpointIO:
return MoECheckpointIO(
self.dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage
self.dp_group,
self.pp_group,
self.tp_group,
self.sp_group,
self.ep_group,
self.moe_dp_group,
self.zero_stage,
)
def configure(
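
Both plugin changes above follow from the widened constructor: HybridParallelCheckpointIO (and MoECheckpointIO, which subclasses it) now also receives the sequence-parallel process group so that only ranks with sp_rank == 0 drive the saving. Code that constructs the checkpoint IO directly instead of going through get_checkpoint_io() has to pass the extra group; a sketch, assuming plugin is an already-initialized HybridParallelPlugin and that the class is importable from colossalai.checkpoint_io:

# Direct construction now takes the sequence-parallel group as the fourth argument.
from colossalai.checkpoint_io import HybridParallelCheckpointIO

checkpoint_io = HybridParallelCheckpointIO(
    plugin.dp_group,
    plugin.pp_group,
    plugin.tp_group,
    plugin.sp_group,   # newly required process group
    plugin.zero_stage,
)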

View File

@ -60,7 +60,7 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
"""
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
if self.coordinator.is_master():
super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor, use_async=use_async)
def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
"""

View File

@ -1,6 +1,6 @@
import os
from pathlib import Path
from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple
import torch
import torch.nn as nn
@ -26,9 +26,11 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO, utils
from colossalai.checkpoint_io.utils import async_save_state_dict_shards, create_pinned_state_dict
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.logging import get_dist_logger
from colossalai.utils.safetensors import load_flat
from .dp_plugin_base import DPPluginBase
@ -49,8 +51,36 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: Path):
assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before loading!"
if checkpoint.endswith(".safetensors"):
checkpoint = load_flat(checkpoint, seperator=".")
else:
checkpoint = utils.load_state_dict(checkpoint)
fsdp_model = optimizer.unwrap_model()
full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=False)
start_index = 0
id2name = {}
def get_index_mapping(group: Dict[str, Any]) -> Dict[str, Any]:
nonlocal start_index
start_num = len(id2name)
id2name.update({i: p for i, p in enumerate(group["params"], start_index) if i not in id2name})
end_num = len(id2name)
start_index += end_num - start_num
for g in full_optimizer_state["param_groups"]:
get_index_mapping(g)
new_state = {}
for key, value in checkpoint["state"].items():
new_state[id2name[int(key)]] = value
checkpoint["state"] = new_state
for g in checkpoint["param_groups"]:
new_group = []
for param_id in g["params"]:
new_group.append(id2name[param_id])
g["params"] = new_group
sharded_osd = FSDP.scatter_full_optim_state_dict(checkpoint, fsdp_model)
optimizer.load_state_dict(sharded_osd)
@ -65,7 +95,21 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
full_model_state = model.state_dict()
utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors)
if self.coordinator.is_master():
if use_async:
from colossalai.utils.safetensors import save
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(full_model_state)
for k, v in full_model_state.items():
self.pinned_state_dicts[id(model)][k].copy_(v)
full_model_state[k] = self.pinned_state_dicts[id(model)][k]
writer = save(checkpoint, full_model_state)
self.async_writers.append(writer)
else:
utils.save_state_dict(
full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors
)
def save_unsharded_optimizer(
self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False
@ -75,7 +119,42 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
"""
assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before saving!"
fsdp_model = optimizer.unwrap_model()
full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True)
if self.coordinator.is_master():
# Save order indices instead of Tensors
name2id: Dict[str, int] = {}
start_index = 0
def pack_group(group: Dict[str, Any]) -> Dict[str, Any]:
nonlocal start_index
packed = {k: v for k, v in group.items() if k != "params"}
name2id.update({p: i for i, p in enumerate(group["params"], start_index) if p not in name2id})
packed["params"] = [name2id[p] for p in group["params"]]
start_index += len(packed["params"])
return packed
param_groups = [pack_group(g) for g in full_optimizer_state["param_groups"]]
full_optimizer_state["param_groups"] = param_groups
new_state = {}
for key, value in full_optimizer_state["state"].items():
new_state[name2id[key]] = value
full_optimizer_state["state"] = new_state
if use_async:
from colossalai.utils.safetensors import _flatten_optim_state_dict, save
flatten_state_dict, metadata = _flatten_optim_state_dict(full_optimizer_state, seperator=".")
if id(optimizer) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(optimizer)] = create_pinned_state_dict(flatten_state_dict)
for k, v in flatten_state_dict.items():
self.pinned_state_dicts[id(optimizer)][k].copy_(v)
flatten_state_dict[k] = self.pinned_state_dicts[id(optimizer)][k]
writer = save(checkpoint, state_dict=flatten_state_dict, metadata=metadata)
self.async_writers.append(writer)
else:
utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False)
def save_sharded_model(
@ -102,12 +181,30 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
):
state_dict = model.unwrap().state_dict()
state_dict_shard = utils.shard_model_checkpoint(state_dict, max_shard_size=size_per_shard)
if use_async and self.coordinator.is_master():
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = {}
pinned_state_dicts = self.pinned_state_dicts[id(model)]
else:
pinned_state_dicts = None
state_dict_shard = utils.shard_model_checkpoint(
state_dict, max_shard_size=size_per_shard, pinned_state_dicts=pinned_state_dicts
)
weights_name, save_index_file = utils.get_model_base_filenames(prefix, use_safetensors)
index_file = CheckpointIndexFile(checkpoint_path)
# In general cases, is_master is set to True to get the right behavior.
if use_async:
total_size, writers = async_save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint_path,
index_file=index_file,
base_filename=weights_name,
is_master=self.coordinator.is_master(),
)
self.async_writers.extend(writers)
else:
total_size = utils.save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint_path,
@ -188,18 +285,58 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
)
if self.coordinator.is_master():
# Save order indices instead of Tensors
name2id: Dict[str, int] = {}
start_index = 0
def pack_group(group: Dict[str, Any]) -> Dict[str, Any]:
nonlocal start_index
packed = {k: v for k, v in group.items() if k != "params"}
name2id.update({p: i for i, p in enumerate(group["params"], start_index) if p not in name2id})
packed["params"] = [name2id[p] for p in group["params"]]
start_index += len(packed["params"])
return packed
param_groups = [pack_group(g) for g in fsdp_optim_state["param_groups"]]
fsdp_optim_state["param_groups"] = param_groups
new_state = {}
for key, value in fsdp_optim_state["state"].items():
new_state[name2id[key]] = value
fsdp_optim_state["state"] = new_state
# Preparing file paths and index file.
states_name, save_index_file, param_group_file = utils.get_optimizer_base_filenames(prefix)
states_name, save_index_file, param_group_file = utils.get_optimizer_base_filenames(
prefix, use_safetensors=use_async
)
index_file = CheckpointIndexFile(checkpoint)
index_file.append_meta_data("param_groups", param_group_file)
group_file_path = os.path.join(checkpoint, param_group_file)
utils.save_param_groups(fsdp_optim_state, group_file_path)
sharded_state = utils.shard_optimizer_checkpoint(fsdp_optim_state, max_shard_size=size_per_shard)
if use_async:
if id(optimizer) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(optimizer)] = {}
pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
else:
pinned_state_dicts = None
sharded_state = utils.shard_optimizer_checkpoint(
fsdp_optim_state, max_shard_size=size_per_shard, pinned_state_dicts=pinned_state_dicts
)
# Save shards of optimizer states.
# In general cases, is_master is set to True to get the right behavior.
if use_async:
total_size, writers = async_save_state_dict_shards(
sharded_state_dict=sharded_state,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=self.coordinator.is_master(),
state_preprocess=True,
)
self.async_writers.extend(writers)
else:
total_size = utils.save_state_dict_shards(
sharded_state_dict=sharded_state,
checkpoint=checkpoint,
@ -239,11 +376,39 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
fsdp_optim_state = {}
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
for shard_file in checkpoint_files:
if shard_file.endswith(".safetensors"):
state_dict_shard = load_flat(shard_file, seperator=".")
else:
state_dict_shard = utils.load_shard_state_dict(Path(shard_file), use_safetensors=False)
fsdp_optim_state.update(state_dict_shard)
fsdp_optim_dict = dict(state=fsdp_optim_state, param_groups=saved_param_groups)
fsdp_model = optimizer.unwrap_model()
full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model.unwrap(), optim=optimizer, rank0_only=False)
start_index = 0
id2name = {}
def get_index_mapping(group: Dict[str, Any]) -> Dict[str, Any]:
nonlocal start_index
start_num = len(id2name)
id2name.update({i: p for i, p in enumerate(group["params"], start_index) if i not in id2name})
end_num = len(id2name)
start_index += end_num - start_num
for g in full_optimizer_state["param_groups"]:
get_index_mapping(g)
new_state = {}
for key, value in fsdp_optim_dict["state"].items():
new_state[id2name[int(key)]] = value
fsdp_optim_dict["state"] = new_state
for g in fsdp_optim_dict["param_groups"]:
new_group = []
for param_id in g["params"]:
new_group.append(id2name[param_id])
g["params"] = new_group
with FSDP.state_dict_type(optimizer.unwrap_model().unwrap(), StateDictType.FULL_STATE_DICT):
fsdp_state = FSDP.optim_state_dict_to_load(
model=optimizer.unwrap_model().unwrap(), optim=optimizer, optim_state_dict=fsdp_optim_dict
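
The name/index bookkeeping above exists because FSDP's full_optim_state_dict keys parameter states by fully-qualified parameter name, while the flattened safetensors layout wants stable integer ids. On save, pack_group assigns contiguous ids and rewrites both param_groups and state; on load, get_index_mapping rebuilds the inverse mapping from the freshly gathered full state and translates the checkpoint back before scatter_full_optim_state_dict. A toy, framework-free illustration of the same rewrite (the optimizer structure and the "t0"/"t1" placeholders are made up):

# Toy illustration of the name<->id rewrite used by the FSDP optimizer save/load paths above.
from typing import Any, Dict

full_optim_state = {
    "state": {"layer1.weight": {"exp_avg": "t0"}, "layer1.bias": {"exp_avg": "t1"}},
    "param_groups": [{"lr": 1e-3, "params": ["layer1.weight", "layer1.bias"]}],
}

# Save direction: replace parameter names with contiguous integer ids.
name2id: Dict[str, int] = {}
start_index = 0

def pack_group(group: Dict[str, Any]) -> Dict[str, Any]:
    global start_index
    packed = {k: v for k, v in group.items() if k != "params"}
    name2id.update({p: i for i, p in enumerate(group["params"], start_index) if p not in name2id})
    packed["params"] = [name2id[p] for p in group["params"]]
    start_index += len(packed["params"])
    return packed

indexed = {
    "param_groups": [pack_group(g) for g in full_optim_state["param_groups"]],
    "state": {name2id[name]: s for name, s in full_optim_state["state"].items()},
}
# indexed["param_groups"][0]["params"] == [0, 1]; the integer state keys survive flattening to safetensors.

# Load direction: rebuild id -> name from the current model's gathered state and translate back.
id2name = {i: n for n, i in name2id.items()}
restored = {
    "param_groups": [{**g, "params": [id2name[i] for i in g["params"]]} for g in indexed["param_groups"]],
    "state": {id2name[int(key)]: value for key, value in indexed["state"].items()},
}
assert restored == full_optim_state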

View File

@ -8,10 +8,12 @@ from typing import Optional
import torch.nn as nn
from torch.optim import Optimizer
from colossalai.utils.safetensors import load_flat
from .checkpoint_io_base import CheckpointIO
from .index_file import CheckpointIndexFile
from .utils import (
async_save_state_dict_shards,
async_move_save_state_dict_shards,
create_pinned_state_dict,
get_model_base_filenames,
get_optimizer_base_filenames,
@ -47,10 +49,6 @@ class GeneralCheckpointIO(CheckpointIO):
):
state_dict = model.state_dict()
# TODO(FrankLeeeee): add support for gather_dtensor
if gather_dtensor:
pass
if use_async:
from colossalai.utils.safetensors import move_and_save
@ -58,7 +56,6 @@ class GeneralCheckpointIO(CheckpointIO):
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
writer = move_and_save(checkpoint, state_dict, self.pinned_state_dicts[id(model)])
self.async_writers.append(writer)
else:
# save the checkpoint
save_state_dict(state_dict, checkpoint, use_safetensors)
@ -83,6 +80,9 @@ class GeneralCheckpointIO(CheckpointIO):
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
for shard_file in checkpoint_files:
if shard_file.endswith(".safetensors"):
state_dict = load_flat(shard_file)
else:
state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
load_states_into_optimizer(optimizer, state_dict, id_map)
@ -116,7 +116,7 @@ class GeneralCheckpointIO(CheckpointIO):
sharded_state = shard_optimizer_checkpoint(state_dict, max_shard_size=size_per_shard)
# Preparing file paths and index file.
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix, use_safetensors=use_async)
index_file = CheckpointIndexFile(checkpoint)
# Store the information of param groups to param_group_file.
@ -126,6 +126,20 @@ class GeneralCheckpointIO(CheckpointIO):
# Save shards of optimizer states.
# In general cases, is_master is set to True to get the right behavior.
if use_async:
pinned_state_dict = self.pinned_state_dicts.get(id(optimizer), None)
total_size, new_pinned_state_dict, writers = async_move_save_state_dict_shards(
sharded_state_dict=sharded_state,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=True,
pinned_state_dict=pinned_state_dict,
state_preprocess=True,
)
self.pinned_state_dicts[id(optimizer)] = new_pinned_state_dict
self.async_writers.extend(writers)
else:
total_size = save_state_dict_shards(
sharded_state_dict=sharded_state,
checkpoint=checkpoint,
@ -145,6 +159,9 @@ class GeneralCheckpointIO(CheckpointIO):
)
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
if checkpoint.endswith(".safetensors"):
checkpoint = load_flat(checkpoint)
else:
checkpoint = load_state_dict(checkpoint)
optimizer.load_state_dict(checkpoint)
@ -156,7 +173,22 @@ class GeneralCheckpointIO(CheckpointIO):
use_async: bool = False,
):
# TODO(FrankLeeeee): handle distributed tensors
save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False)
state_dict = optimizer.state_dict()
if use_async:
from colossalai.utils.safetensors import _flatten_optim_state_dict, move_and_save
flatten_state_dict, metadata = _flatten_optim_state_dict(state_dict)
if id(optimizer) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(optimizer)] = create_pinned_state_dict(flatten_state_dict)
writer = move_and_save(
path=checkpoint,
state_dict=flatten_state_dict,
state_dict_pinned=self.pinned_state_dicts[id(optimizer)],
metadata=metadata,
)
self.async_writers.append(writer)
else:
save_state_dict(state_dict, checkpoint, use_safetensors=False)
def save_sharded_model(
self,
@ -186,7 +218,7 @@ class GeneralCheckpointIO(CheckpointIO):
if use_async:
pinned_state_dict = self.pinned_state_dicts.get(id(model), None)
total_size, new_pinned_state_dict, writers = async_save_state_dict_shards(
total_size, new_pinned_state_dict, writers = async_move_save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint_path,
index_file=index_file,
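
For the unsharded async optimizer path above, the flattening and the device-to-host copy are combined: _flatten_optim_state_dict turns the nested state into flat tensor keys plus a metadata record of param_groups and other non-tensor entries, and move_and_save copies each tensor into the cached pinned buffer before queueing the write. A rough sketch of that call sequence, using a throwaway Adam optimizer as a stand-in for the boosted one (pinned allocation assumes a CUDA-capable environment):

# Sketch of the async unsharded optimizer save above; the Adam setup stands in for a boosted optimizer.
import torch
from colossalai.checkpoint_io.utils import create_pinned_state_dict
from colossalai.utils.safetensors import _flatten_optim_state_dict, move_and_save

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model(torch.randn(2, 8)).sum().backward()
optimizer.step()  # populate exp_avg / exp_avg_sq / step states

flatten_state_dict, metadata = _flatten_optim_state_dict(optimizer.state_dict())
pinned = create_pinned_state_dict(flatten_state_dict)  # cached per optimizer in self.pinned_state_dicts
writer = move_and_save(
    path="optimizer.safetensors",
    state_dict=flatten_state_dict,
    state_dict_pinned=pinned,
    metadata=metadata,  # carries param_groups and other non-tensor entries
)
# A CheckpointIO appends `writer` to self.async_writers and flushes it in _sync_d2h()/_sync_io().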

View File

@ -22,6 +22,7 @@ from colossalai.tensor.padded_tensor import (
to_unpadded_tensor,
)
from colossalai.utils import get_current_device, get_non_persistent_buffers_set
from colossalai.utils.safetensors import _flatten_optim_state_dict, load_flat
from .general_checkpoint_io import GeneralCheckpointIO
from .index_file import CheckpointIndexFile
@ -69,6 +70,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
dp_group: ProcessGroup,
pp_group: ProcessGroup,
tp_group: ProcessGroup,
sp_group: ProcessGroup,
zero_stage: int,
verbose: bool = True,
) -> None:
@ -76,9 +78,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
self.global_dp_group = dp_group
self.pp_group = pp_group
self.tp_group = tp_group
self.sp_group = sp_group
self.dp_rank = dist.get_rank(self.global_dp_group)
self.tp_rank = dist.get_rank(self.tp_group)
self.pp_rank = dist.get_rank(self.pp_group)
self.sp_rank = dist.get_rank(self.sp_group)
self.global_dp_size = dist.get_world_size(dp_group)
self.pp_size = dist.get_world_size(pp_group)
self.tp_size = dist.get_world_size(tp_group)
@ -88,7 +92,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
@staticmethod
def _model_sharder(
model: nn.Module, prefix: str = "", keep_vars: bool = False, size_per_shard: int = 1024
model: nn.Module,
prefix: str = "",
keep_vars: bool = False,
size_per_shard: int = 1024,
pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None,
) -> Iterator[Tuple[OrderedDict, int]]:
# An internal method that breaks state_dict of model into shards within limited size.
@ -102,6 +110,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
if is_padded_tensor(param):
param = to_unpadded_tensor(param)
param_ = gather_distributed_param(param, keep_vars=False)
if pinned_state_dicts is not None:
if (prefix + name) not in pinned_state_dicts:
pinned_state_dicts[prefix + name] = torch.empty_like(param_, pin_memory=True, device="cpu")
pinned_state_dicts[prefix + name].copy_(param_)
param_ = pinned_state_dicts[prefix + name]
block, block_size = state_dict_sharder.append_param(prefix + name, param_)
if block is not None:
yield block, block_size
@ -111,6 +124,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
for name, buf in model.named_buffers():
if buf is not None and name not in non_persist_buffers_set:
buffer = buf if keep_vars else buf.detach()
if pinned_state_dicts is not None:
if (prefix + name) not in pinned_state_dicts:
pinned_state_dicts[prefix + name] = torch.empty_like(buffer, pin_memory=True, device="cpu")
pinned_state_dicts[prefix + name].copy_(buffer)
buffer = pinned_state_dicts[prefix + name]
block, block_size = state_dict_sharder.append_param(prefix + name, buffer)
if block is not None:
yield block, block_size
@ -122,6 +140,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
is not torch.nn.Module.get_extra_state
):
extra_state = model.get_extra_state()
if pinned_state_dicts is not None:
if extra_state_key not in pinned_state_dicts:
pinned_state_dicts[extra_state_key] = torch.empty_like(extra_state, pin_memory=True, device="cpu")
pinned_state_dicts[extra_state_key].copy_(extra_state)
extra_state = pinned_state_dicts[extra_state_key]
block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state)
if block is not None:
yield block, block_size
@ -136,6 +159,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
dp_group: ProcessGroup,
tp_group: ProcessGroup,
size_per_shard: int = 1024,
pinned_state_dicts: Optional[Dict[int, Dict[str, torch.Tensor]]] = None,
):
# An internal method that breaks state_dict of optimizer into shards within limited size.
@ -153,6 +177,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
working_param = param
param_id = param_info["param2id"][id(working_param)]
if pinned_state_dicts is not None:
if param_id not in pinned_state_dicts:
pinned_state_dicts[param_id] = {}
original_shape = param_info["param2shape"][id(working_param)]
state_ = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state(
state,
@ -162,6 +189,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
tp_group=tp_group,
use_zero=use_zero,
inplace=False,
pinned_state_dicts=pinned_state_dicts[param_id] if pinned_state_dicts is not None else None,
)
block, block_size = state_dict_sharder.append_optim_state(param_id, state_)
@ -216,15 +244,31 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
# Then collect the sharded parameters & buffers along tp_group.
# Only devices with tp_rank == 0 are responsible for model saving.
state_dict_shard = HybridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard)
control_saving = self.tp_rank == 0 and self.sp_rank == 0
if control_saving and use_async:
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = {}
pinned_state_dicts = self.pinned_state_dicts[id(model)]
else:
pinned_state_dicts = None
state_dict_shard = HybridParallelCheckpointIO._model_sharder(
model, size_per_shard=size_per_shard, pinned_state_dicts=pinned_state_dicts
)
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
index_file = CheckpointIndexFile(checkpoint)
control_saving = self.tp_rank == 0
if self.pp_size == 1:
# When pipeline is not used, save the model shards as in general checkpointIO
if use_async:
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async)
total_size, writers = async_save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
state_preprocess=False,
)
self.async_writers.extend(writers)
else:
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
@ -259,15 +303,16 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
save_index_file = os.path.join("tmp_index_files", save_index_file)
if use_async:
total_size, returned_state_dict, writers = async_save_state_dict_shards(
total_size, writers = async_save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_pp_format=True,
n_write_entries=191,
state_preprocess=False,
)
self.async_writers.extend(writers)
else:
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
@ -448,19 +493,39 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
# Then collect the sharded states along dp_group(if using zero)/tp_group.
# Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving.
control_saving = self.dp_rank == 0 and self.tp_rank == 0 and self.sp_rank == 0
if use_async and control_saving:
if id(optimizer) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(optimizer)] = {}
pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
else:
pinned_state_dicts = None
state_dict_shard = HybridParallelCheckpointIO._optimizer_sharder(
optimizer,
use_zero=self.use_zero,
dp_group=self.global_dp_group,
tp_group=self.tp_group,
size_per_shard=size_per_shard,
pinned_state_dicts=pinned_state_dicts,
)
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix, use_safetensors=use_async)
index_file = CheckpointIndexFile(checkpoint)
control_saving = self.dp_rank == 0 and self.tp_rank == 0
if self.pp_size == 1:
# When pipeline is not used, save the optimizer shards as in general checkpointIO
if use_async:
total_size, writers = async_save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=control_saving,
use_pp_format=True,
state_preprocess=True,
)
self.async_writers.extend(writers)
else:
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
@ -498,10 +563,25 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
# Manage filenames of sharded weights and index file for each pipeline stage.
if not use_async:
states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
else:
states_name = states_name.replace(".safetensors", f"-stage-{self.pp_rank+1:05d}-shard.safetensors")
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
save_index_file = os.path.join("tmp_index_files", save_index_file)
if use_async:
total_size, writers = async_save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=control_saving,
use_pp_format=True,
state_preprocess=True,
)
self.async_writers.extend(writers)
else:
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
@ -622,6 +702,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
continue
file_path = os.path.join(ckpt_root_path, filename)
if file_path.endswith(".safetensors"):
state_dict = load_flat(file_path)
else:
state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
loaded_file.add(filename)
@ -672,7 +755,15 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
# When pipeline is not used, let master rank directly save the collected state_dict.
if self.tp_rank == 0:
if use_async:
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async)
from colossalai.utils.safetensors import save
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
for name, param in state_dict.items():
self.pinned_state_dicts[id(model)][name].copy_(param)
state_dict[name] = self.pinned_state_dicts[id(model)][name]
writer = save(path=checkpoint, state_dict=state_dict)
self.async_writers.append(writer)
else:
save_state_dict(state_dict, checkpoint, use_safetensors)
else:
@ -686,12 +777,14 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
for _state_dict in state_dict_list:
complete_state_dict.update(_state_dict)
if use_async:
from colossalai.utils.safetensors import move_and_save
from colossalai.utils.safetensors import save
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
writer = move_and_save(checkpoint, state_dict, self.pinned_state_dicts[id(model)])
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(complete_state_dict)
for name, param in complete_state_dict.items():
self.pinned_state_dicts[id(model)][name].copy_(param)
complete_state_dict[name] = self.pinned_state_dicts[id(model)][name]
writer = save(path=checkpoint, state_dict=complete_state_dict)
self.async_writers.append(writer)
else:
save_state_dict(complete_state_dict, checkpoint, use_safetensors)
@ -757,6 +850,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
# gather complete state from tp shards & dp shards
param_id = optimizer.param_info["param2id"][id(working_param)]
original_shape = optimizer.param_info["param2shape"][id(working_param)]
local_states[param_id] = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state(
state,
working_param,
@ -776,6 +870,18 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
]
state_dict = {"param_groups": param_groups, "state": local_states}
if self.coordinator.is_master():
if use_async:
from colossalai.utils.safetensors import save
flatten_state_dict, metadata = _flatten_optim_state_dict(state_dict)
if id(optimizer) not in self.pinned_state_dicts:
self.pinned_state_dicts = create_pinned_state_dict(flatten_state_dict)
for k, v in flatten_state_dict.items():
self.pinned_state_dicts[k].copy_(v)
flatten_state_dict[k] = self.pinned_state_dicts[k]
writer = save(path=checkpoint, state_dict=flatten_state_dict, metadata=metadata)
self.async_writers.append(writer)
else:
save_state_dict(state_dict, checkpoint, use_safetensors=False)
else:
# When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict.
@ -792,6 +898,18 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
state_dict = {"param_groups": param_groups, "state": dict()}
for _states in states_list:
state_dict["state"].update(_states)
if use_async:
from colossalai.utils.safetensors import save
flatten_state_dict, metadata = _flatten_optim_state_dict(state_dict)
if id(optimizer) not in self.pinned_state_dicts:
self.pinned_state_dicts = create_pinned_state_dict(flatten_state_dict)
for k, v in flatten_state_dict.items():
self.pinned_state_dicts[k].copy_(v)
flatten_state_dict[k] = self.pinned_state_dicts[k]
writer = save(path=checkpoint, state_dict=flatten_state_dict, metadata=metadata)
self.async_writers.append(writer)
else:
save_state_dict(state_dict, checkpoint, use_safetensors=False)
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
@ -818,6 +936,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
# Complete optimizer state_dict loaded from checkpoint, need to be processed later.
if checkpoint.endswith(".safetensors"):
state_dict = load_flat(checkpoint)
else:
state_dict = load_state_dict(checkpoint)
# Load param_groups.
@ -872,6 +993,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
use_zero: bool,
inplace: bool,
device: torch.device = torch.device("cpu"),
pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None,
) -> OrderedDict:
"""
With given parameter and its optimizer states, gather the complete optimizer state for saving.
@ -895,6 +1017,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
state_ = state if inplace else copy.deepcopy(state)
for k, v in state_.items():
if v is None:
continue
if isinstance(v, torch.Tensor) and k != "step":
# First gather Zero shards.
if use_zero:
@ -915,6 +1039,12 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
v = init_as_padded_tensor(v, v.shape[padding_dim], original_shape[padding_dim], padding_dim)
v = to_unpadded_tensor(v)
if pinned_state_dicts is not None:
if k not in pinned_state_dicts:
pinned_state_dicts[k] = torch.empty_like(v, pin_memory=True, device="cpu")
pinned_state_dicts[k].copy_(v)
state_[k] = pinned_state_dicts[k]
else:
state_[k] = v.detach().clone().to(device)
return state_

View File

@ -44,12 +44,13 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
global_dp_group: ProcessGroup,
pp_group: ProcessGroup,
tp_group: ProcessGroup,
sp_group: ProcessGroup,
ep_group: ProcessGroup,
moe_dp_group: ProcessGroup,
zero_stage: int,
verbose: bool = True,
) -> None:
super().__init__(global_dp_group, pp_group, tp_group, zero_stage, verbose)
super().__init__(global_dp_group, pp_group, tp_group, sp_group, zero_stage, verbose)
self.global_dp_group = global_dp_group
self.global_dp_rank = dist.get_rank(global_dp_group)
self.global_dp_size = dist.get_world_size(global_dp_group)
@ -158,7 +159,7 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
state_dict_shard = MoECheckpointIO._model_sharder(model, size_per_shard=size_per_shard)
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
index_file = CheckpointIndexFile(checkpoint)
control_saving = self.tp_rank == 0
control_saving = self.tp_rank == 0 and self.sp_rank == 0
if self.pp_size == 1 and self.ep_size == 1:
# When pipeline is not used, save the model shards as in general checkpointIO
@ -415,7 +416,7 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
# e.g. dp_size = 4, moe_dp_size = 2, ep_size = 2 and use gather
# rank 0 saves moe & non-moe params; rank 1 only saves moe params
# rank 3 & 4 save nothing
control_saving = self.tp_rank == 0 and self.moe_dp_rank == 0
control_saving = self.tp_rank == 0 and self.moe_dp_rank == 0 and self.sp_rank == 0
if self.pp_size == 1 and self.ep_size == 1:
# When pipeline is not used, save the optimizer shards as in general checkpointIO

View File

@ -19,6 +19,7 @@ from colossalai.tensor.d_tensor import (
to_global,
to_global_for_customized_distributed_tensor,
)
from colossalai.utils.safetensors import _flatten_optim_state_dict
SAFE_WEIGHTS_NAME = "model.safetensors"
WEIGHTS_NAME = "pytorch_model.bin"
@ -266,6 +267,63 @@ def save_state_dict_shards(
def async_save_state_dict_shards(
sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
checkpoint: str,
index_file: "CheckpointIndexFile",
base_filename: str,
is_master: bool,
use_pp_format: bool = False,
state_preprocess: bool = False,
) -> Tuple[int, list]:
"""
Save sharded state dict only on master rank, this method can be used by both model and optimizer states.
Args:
sharded_state_dict (Iterator[Tuple[OrderedDict, int]]): a generator of shards, each shard contains state dict and shard size.
checkpoint (str): The path of checkpoint directory as string.
index_file (CheckpointIndexFile): The index file object to be updated.
base_filename (str): Decides the prefix of filenames of shards.
is_master (bool): Whether current rank is main process.
use_pp_format (bool, optional): Whether to save the files in pipeline format including stage information. Defaults to False.
state_preprocess (bool, optional): Whether to flatten nested optimizer states via _flatten_optim_state_dict before writing. Defaults to False.
Returns:
Tuple[int, list]: the total size of shards and the list of async writers.
"""
from colossalai.utils.safetensors import save
total_size = 0
shard_filenames = []
writers = []
for idx, shard_pair in enumerate(sharded_state_dict):
shard, current_size = shard_pair
# Just loop over the sharder and gather to other ranks if not master
if not is_master:
del shard
continue
shard_file = get_shard_filename(base_filename, idx)
total_size = total_size + current_size
for key in shard.keys():
index_file.append_weight_map(key, shard_file)
checkpoint_file_path = os.path.join(checkpoint, shard_file)
if state_preprocess:
state_dict, _ = _flatten_optim_state_dict(state_dict=shard, seperator=".")
else:
state_dict = shard
# Only save on master rank.
writer = save(checkpoint_file_path, state_dict=state_dict)
writers.append(writer)
shard_filenames.append(shard_file)
del shard
# Clean folder, delete unneeded files.
clean_folder(checkpoint, base_filename, shard_filenames, is_master=is_master, use_pp_format=use_pp_format)
return total_size, writers
def async_move_save_state_dict_shards(
sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
checkpoint: str,
index_file: "CheckpointIndexFile",
@ -273,6 +331,7 @@ def async_save_state_dict_shards(
is_master: bool,
pinned_state_dict: Optional[Dict[str, torch.Tensor]],
use_pp_format: bool = False,
state_preprocess: bool = False,
) -> Tuple[int, Dict[str, torch.Tensor], list]:
"""
Save sharded state dict only on master rank, this method can be used by both model and optimizer states.
@ -309,14 +368,19 @@ def async_save_state_dict_shards(
index_file.append_weight_map(key, shard_file)
checkpoint_file_path = os.path.join(checkpoint, shard_file)
if pinned_state_dict is not None:
sub_pinned_state_dict = {k: pinned_state_dict[k] for k in shard.keys()}
if state_preprocess:
state_dict, _ = _flatten_optim_state_dict(state_dict=shard)
else:
sub_pinned_state_dict = create_pinned_state_dict(shard)
state_dict = shard
if pinned_state_dict is not None:
sub_pinned_state_dict = {k: pinned_state_dict[k] for k in state_dict.keys()}
else:
sub_pinned_state_dict = create_pinned_state_dict(state_dict)
returned_state_dict.update(sub_pinned_state_dict)
# Only save on master rank.
writer = move_and_save(checkpoint_file_path, shard, sub_pinned_state_dict)
writer = move_and_save(checkpoint_file_path, state_dict, sub_pinned_state_dict)
writers.append(writer)
shard_filenames.append(shard_file)
del shard
@ -327,7 +391,11 @@ def async_save_state_dict_shards(
return total_size, returned_state_dict, writers
def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
def shard_model_checkpoint(
state_dict: torch.Tensor,
max_shard_size: int = 1024,
pinned_state_dicts: Optional[Dict[int, Dict[str, torch.Tensor]]] = None,
) -> Iterator[Tuple[OrderedDict, int]]:
"""
Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
given size.
@ -336,6 +404,11 @@ def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024)
for key, weight in state_dict.items():
if not is_distributed_tensor(weight):
if pinned_state_dicts is not None:
if key not in pinned_state_dicts:
pinned_state_dicts[key] = torch.empty_like(weight, pin_memory=True, device="cpu")
pinned_state_dicts[key].copy_(weight)
weight = pinned_state_dicts[key]
block, block_size = state_dict_sharder.append_param(key, weight)
if block != None:
@ -345,7 +418,9 @@ def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024)
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
def shard_optimizer_checkpoint(
state_dict: dict, max_shard_size: int = 1024, pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None
) -> Iterator[Tuple[OrderedDict, int]]:
"""
Splits an optimizer state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
given size.
@ -356,6 +431,15 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
state_dict_sharder = StateDictSharder(max_shard_size)
for param_id, state in states.items():
if pinned_state_dicts is not None:
if param_id not in pinned_state_dicts:
pinned_state_dicts[param_id] = {}
for k, v in state.items():
if k not in pinned_state_dicts[param_id]:
pinned_state_dicts[param_id][k] = torch.empty_like(v, pin_memory=True, device="cpu")
pinned_state_dicts[param_id][k].copy_(v)
state[k] = pinned_state_dicts[param_id][k]
block, block_size = state_dict_sharder.append_optim_state(param_id, state)
if block != None:
yield block, block_size
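
Note that the helpers above split duties: the new async_save_state_dict_shards assumes the sharder has already staged its tensors in pinned memory (via the pinned_state_dicts arguments threaded through the sharders in this diff) and writes each shard with save(), optionally flattening optimizer shards when state_preprocess=True, while the renamed async_move_save_state_dict_shards keeps managing the pinned copies itself through move_and_save. A minimal driver for the new writer with a toy one-shard generator (the shard contents, directory name and prefix handling are illustrative):

# Minimal driver for async_save_state_dict_shards with a toy one-shard generator.
import os
import torch
from colossalai.checkpoint_io import CheckpointIndexFile
from colossalai.checkpoint_io.utils import async_save_state_dict_shards, get_model_base_filenames

def toy_sharder():
    # Stand-in for _model_sharder / shard_model_checkpoint; real sharders hand over tensors
    # that were already copied into pinned CPU buffers.
    shard = {"weight": torch.zeros(4, 4)}
    yield shard, sum(t.numel() * t.element_size() for t in shard.values())

ckpt_dir = "ckpt_dir"
os.makedirs(ckpt_dir, exist_ok=True)
weights_name, save_index_file = get_model_base_filenames(prefix=None, use_safetensors=True)
index_file = CheckpointIndexFile(ckpt_dir)

total_size, writers = async_save_state_dict_shards(
    sharded_state_dict=toy_sharder(),
    checkpoint=ckpt_dir,
    index_file=index_file,
    base_filename=weights_name,
    is_master=True,
    state_preprocess=False,  # True for optimizer shards, so nested per-param states get flattened first
)
# The writers are normally collected in CheckpointIO.async_writers and flushed by _sync_d2h()/_sync_io().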

View File

@ -71,6 +71,8 @@ def _flatten_optim_state_dict(state_dict: dict, seperator: str = ".") -> Tuple[d
for idx, d in states.items():
for k, v in d.items():
if v is None:
continue
nested_key = f"state{seperator}{idx}{seperator}{k}"
if not isinstance(v, torch.Tensor):
non_tensor_keys.append(nested_key)
@ -87,7 +89,8 @@ def _flatten_optim_state_dict(state_dict: dict, seperator: str = ".") -> Tuple[d
def _unflatten_optim_state_dict(flat_dict: dict, metadata: Optional[dict] = None, seperator: str = "."):
state_dict = {}
if metadata is not None:
if metadata is not None and "non_tensor_keys" in metadata:
non_tensor_keys = json.loads(metadata["non_tensor_keys"])
else:
non_tensor_keys = []
@ -128,8 +131,10 @@ def prepare(
header = {}
offset = 0
header_metadata = {"format": "pt"}
if metadata is not None:
header["__metadata__"] = metadata
header_metadata.update(metadata)
header["__metadata__"] = header_metadata
for name, tensor in data.items():
n = tensor.numel() * tensor.element_size()
@ -172,8 +177,9 @@ def move_and_save(
path: str,
state_dict: Dict[str, torch.Tensor],
state_dict_pinned: Optional[Dict[str, torch.Tensor]] = None,
metadata: Optional[Dict[str, str]] = None,
) -> None:
prepared_data, _, tensor_keys = prepare(state_dict)
prepared_data, _, tensor_keys = prepare(state_dict, metadata)
n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset
f_writer = AsyncFileWriter(path, n_entries=ASYNC_WRITE_ENTRIES, backend="pthread", n_tasks=2 + len(tensor_keys))
f_writer.write(n.to_bytes(8, byteorder="little"))
@ -188,9 +194,9 @@ def move_and_save(
return f_writer
def load_flat(checkpoint_path):
def load_flat(checkpoint_path, seperator: str = "."):
with safe_open(checkpoint_path, framework="pt") as f:
metadata = f.metadata()
state_dict_load = load_file(checkpoint_path)
state_dict = _unflatten_optim_state_dict(state_dict_load, metadata)
state_dict = _unflatten_optim_state_dict(state_dict_load, metadata, seperator)
return state_dict
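
The flattening scheme above is what lets optimizer states ride in a plain safetensors file: tensors are stored under keys like state.{param_id}.{name} (with the separator configurable, which the FSDP paths above rely on), while param_groups and any non-tensor entries travel in the metadata header that prepare() now merges with the format field, and load_flat reverses the transformation after reading. An in-memory round trip on a throwaway Adam state, using the private helpers from this file; how the pieces compose without file I/O is an assumption on my part.

# In-memory flatten/unflatten round trip mirroring what save() + load_flat() do on disk.
import torch
from colossalai.utils.safetensors import _flatten_optim_state_dict, _unflatten_optim_state_dict

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model(torch.randn(1, 4)).sum().backward()
optimizer.step()

flat, metadata = _flatten_optim_state_dict(optimizer.state_dict())
print(sorted(flat))    # keys like "state.0.exp_avg", "state.0.exp_avg_sq", ...
nested = _unflatten_optim_state_dict(flat, metadata)
print(sorted(nested))  # back to the usual ["param_groups", "state"] layout
# On disk, save(path, flat, metadata) writes the same layout and load_flat(path) performs the
# unflatten step after reading, which is how the .safetensors optimizer checkpoints above are loaded.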

View File

@ -903,6 +903,7 @@ class GeminiDDP(ModelWrapper):
keep_vars: bool = False,
max_shard_size: int = 1024,
only_rank_0: bool = True,
pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None,
) -> Iterator[Tuple[OrderedDict, int]]:
"""Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.
@ -943,6 +944,13 @@ class GeminiDDP(ModelWrapper):
gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0))
gathered_param = gathered_param_buffer.pop(param_to_save)
if pinned_state_dicts is not None:
if (prefix + name) not in pinned_state_dicts:
pinned_state_dicts[prefix + name] = torch.empty_like(
gathered_param, pin_memory=True, device="cpu"
)
pinned_state_dicts[prefix + name].copy_(gathered_param)
gathered_param = pinned_state_dicts[prefix + name]
block, block_size = sharder.append_param(prefix + name, gathered_param)
if block is not None:
yield block, block_size
@ -954,6 +962,11 @@ class GeminiDDP(ModelWrapper):
for name, buf in self.named_buffers():
if buf is not None and name not in self._non_persistent_buffers_set:
buffer = buf if keep_vars else buf.detach()
if pinned_state_dicts is not None:
if (prefix + name) not in pinned_state_dicts:
pinned_state_dicts[prefix + name] = torch.empty_like(buffer, pin_memory=True, device="cpu")
pinned_state_dicts[prefix + name].copy_(buffer)
buffer = pinned_state_dicts[prefix + name]
block, block_size = sharder.append_param(prefix + name, buffer)
if block is not None:
yield block, block_size
@ -964,6 +977,11 @@ class GeminiDDP(ModelWrapper):
is not torch.nn.Module.get_extra_state
):
extra_state = self.get_extra_state()
if pinned_state_dicts is not None:
if extra_state_key not in pinned_state_dicts:
pinned_state_dicts[extra_state_key] = torch.empty_like(extra_state, pin_memory=True, device="cpu")
pinned_state_dicts[extra_state_key].copy_(extra_state)
extra_state = pinned_state_dicts[extra_state_key]
block, block_size = sharder.append_param(extra_state_key, extra_state)
if block is not None:
yield block, block_size

View File

@ -809,7 +809,11 @@ class GeminiOptimizer(OptimizerWrapper):
self.optimizer_loading_epilogue()
def state_shard(
self, prefix: str = "", max_shard_size: int = 1024, only_rank_0: bool = True
self,
prefix: str = "",
max_shard_size: int = 1024,
only_rank_0: bool = True,
pinned_state_dicts: Optional[Dict[int, Dict[str, torch.Tensor]]] = None,
) -> Iterator[Tuple[OrderedDict, int]]:
"""Returns dictionaries containing shards of optimizer states one by one.
The max size of each dictionary shard is specified by ``max_shard_size``.
@ -829,6 +833,16 @@ class GeminiOptimizer(OptimizerWrapper):
dist.barrier()
state = self.collect_states(param_id=param_id, only_rank_0=only_rank_0)
if pinned_state_dicts is not None:
if param_id not in pinned_state_dicts:
pinned_state_dicts[param_id] = {}
for k, v in state.items():
if v is None:
continue
if k not in pinned_state_dicts[param_id]:
pinned_state_dicts[param_id][k] = torch.empty_like(v, pin_memory=True, device="cpu")
pinned_state_dicts[param_id][k].copy_(v)
state[k] = pinned_state_dicts[param_id][k]
block, block_size = sharder.append_optim_state(param_id, state)
if block is not None:
yield block, block_size

View File

@ -35,7 +35,10 @@ OPTIM_PLACEMENT_CONFIGS = [
@parameterize("use_safetensors", [False, True])
@parameterize("tp_size", [1, 2])
@parameterize("zero_size", [2])
def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool, tp_size: int, zero_size: int):
@parameterize("use_async", [False, True])
def exam_state_dict_with_origin(
placement_config, model_name, use_safetensors: bool, tp_size: int, zero_size: int, use_async: bool
):
from transformers import BertForSequenceClassification
(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
@ -70,7 +73,10 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
"",
(model_size / 3),
use_safetensors=use_safetensors,
use_async=use_async,
)
booster.checkpoint_io._sync_d2h()
booster.checkpoint_io._sync_io()
dist.barrier()
new_bert_model = BertForSequenceClassification.from_pretrained(pretrained_path)
check_state_dict_equal(bert_model.state_dict(only_rank_0=False), new_bert_model.state_dict())
@ -83,7 +89,10 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
@parameterize("size_per_shard", [32])
@parameterize("tp_size", [1, 2])
@parameterize("zero_size", [2])
def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int, tp_size: int, zero_size: int):
@parameterize("use_async", [False, True])
def exam_state_dict(
placement_config, shard: bool, model_name: str, size_per_shard: int, tp_size: int, zero_size: int, use_async: bool
):
(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
criterion = lambda x: x.mean()
enable_flash_attention = True if tp_size > 1 else False
@ -124,14 +133,18 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
with shared_tempdir() as tempdir:
model_ckpt_path = f"{tempdir}/model"
optimizer_ckpt_path = f"{tempdir}/optimizer"
booster.save_model(
model,
model_ckpt_path,
shard=shard,
size_per_shard=size_per_shard,
)
booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
if not shard and use_async:
model_ckpt_path = f"{model_ckpt_path}.safetensors"
optimizer_ckpt_path = f"{optimizer_ckpt_path}.safetensors"
booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async)
booster.save_optimizer(
optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async
)
booster.checkpoint_io._sync_d2h()
booster.checkpoint_io._sync_io()
dist.barrier()
booster.load_model(new_model, model_ckpt_path)
@ -155,8 +168,18 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
loss = criterion(output[output_key])
booster.backward(loss, new_optimizer)
new_optimizer.step()
booster.save_model(new_model, model_ckpt_path, shard=shard)
booster.save_optimizer(new_optimizer, optimizer_ckpt_path, shard=shard)
with shared_tempdir() as new_tempdir:
model_ckpt_path = f"{new_tempdir}/model"
optimizer_ckpt_path = f"{new_tempdir}/optimizer"
if not shard and use_async:
model_ckpt_path = f"{model_ckpt_path}.safetensors"
optimizer_ckpt_path = f"{optimizer_ckpt_path}.safetensors"
booster.save_model(new_model, model_ckpt_path, shard=shard, use_async=use_async)
booster.save_optimizer(new_optimizer, optimizer_ckpt_path, shard=shard, use_async=use_async)
booster.checkpoint_io._sync_d2h()
booster.checkpoint_io._sync_io()
def exam_lazy_from_pretrained():

View File

@ -19,7 +19,8 @@ from colossalai.testing import check_state_dict_equal, clear_cache_before_run, p
@clear_cache_before_run()
@parameterize("use_safetensors", [True, False])
def test_unsharded_checkpoint(use_safetensors: bool):
@parameterize("use_async", [False, True])
def test_unsharded_checkpoint(use_safetensors: bool, use_async: bool):
# create a model and optimizer
model = resnet18()
optimizer = Adam(model.parameters(), lr=0.001)
@ -36,18 +37,21 @@ def test_unsharded_checkpoint(use_safetensors: bool):
lr_scheduler.step()
# create a temp file for checkpoint
if use_safetensors:
if use_async or use_safetensors:
suffix = ".safetensors"
else:
suffix = ".bin"
model_ckpt_tempfile = tempfile.NamedTemporaryFile(suffix=suffix)
if use_async:
optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile(suffix=suffix)
else:
optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile()
lr_scheduler_ckpt_tempfile = tempfile.NamedTemporaryFile()
# save the model, optimizer, lr_scheduler
ckpt_io = GeneralCheckpointIO()
ckpt_io.save_model(model, model_ckpt_tempfile.name, use_safetensors=use_safetensors)
ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name)
ckpt_io.save_model(model, model_ckpt_tempfile.name, use_safetensors=use_safetensors, use_async=use_async)
ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name, use_async=use_async)
ckpt_io.save_lr_scheduler(lr_scheduler, lr_scheduler_ckpt_tempfile.name)
# create new model
@ -55,6 +59,9 @@ def test_unsharded_checkpoint(use_safetensors: bool):
new_optimizer = Adam(new_model.parameters(), lr=0.001)
new_lr_scheduler = CosineAnnealingWarmupLR(optimizer, total_steps=10)
ckpt_io._sync_d2h()
ckpt_io._sync_io()
# load the model, optimizer, lr_scheduler
ckpt_io.load_model(new_model, model_ckpt_tempfile.name)
ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name)
@ -66,7 +73,8 @@ def test_unsharded_checkpoint(use_safetensors: bool):
@pytest.mark.parametrize("use_safetensors", [True, False])
def test_sharded_model_checkpoint(use_safetensors: bool):
@pytest.mark.parametrize("use_async", [False, True])
def test_sharded_model_checkpoint(use_safetensors: bool, use_async: bool):
# create a model and optimizer
model = resnet18()
optimizer = Adam(model.parameters(), lr=0.001)
@ -79,21 +87,20 @@ def test_sharded_model_checkpoint(use_safetensors: bool):
loss.backward()
optimizer.step()
# create a temp file for checkpoint
if use_safetensors:
pass
else:
pass
model_ckpt_dir = tempfile.TemporaryDirectory()
optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile()
# save the model and optimizer
ckpt_io = GeneralCheckpointIO()
ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=use_safetensors)
ckpt_io.save_model(
model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=use_safetensors, use_async=use_async
)
ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name, shard=False)
ckpt_io._sync_d2h()
ckpt_io._sync_io()
# create new model
new_model = resnet18()
new_optimizer = Adam(new_model.parameters(), lr=0.001)
@ -106,7 +113,8 @@ def test_sharded_model_checkpoint(use_safetensors: bool):
check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict())
def test_sharded_optimizer_checkpoint():
@pytest.mark.parametrize("use_async", [False, True])
def test_sharded_optimizer_checkpoint(use_async: bool):
# create a model and optimizer
model = resnet18()
optimizer = Adam(model.parameters(), lr=0.001)
@ -128,7 +136,10 @@ def test_sharded_optimizer_checkpoint():
ckpt_io = GeneralCheckpointIO()
ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=False)
ckpt_io.save_optimizer(optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10)
ckpt_io.save_optimizer(optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10, use_async=use_async)
ckpt_io._sync_d2h()
ckpt_io._sync_io()
# create new model
new_model = resnet18()
@ -148,9 +159,16 @@ def test_sharded_optimizer_checkpoint():
loss.backward()
new_optimizer.step()
# create temp directories for checkpoint
model_ckpt_dir = tempfile.TemporaryDirectory()
optimizer_ckpt_dir = tempfile.TemporaryDirectory()
# save the newly got optimizer
ckpt_io.save_model(new_model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=False)
ckpt_io.save_optimizer(new_optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10)
ckpt_io.save_optimizer(new_optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10, use_async=use_async)
ckpt_io._sync_d2h()
ckpt_io._sync_io()
# create another new model
new_new_model = resnet18()
@ -164,7 +182,8 @@ def test_sharded_optimizer_checkpoint():
check_state_dict_equal(new_optimizer.state_dict(), new_new_optimizer.state_dict())
def test_sharded_optimizer_multiple_param_groups():
@pytest.mark.parametrize("use_async", [False, True])
def test_sharded_optimizer_multiple_param_groups(use_async: bool):
# create a model and optimizer
model = resnet18()
optimizer = Adam(
@ -188,7 +207,10 @@ def test_sharded_optimizer_multiple_param_groups():
ckpt_io = GeneralCheckpointIO()
ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=False)
ckpt_io.save_optimizer(optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10)
ckpt_io.save_optimizer(optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10, use_async=use_async)
ckpt_io._sync_d2h()
ckpt_io._sync_io()
# create new model
new_model = resnet18()

View File

@ -38,12 +38,13 @@ else:
]
@parameterize("shard", [True, False])
@parameterize("shard", [False, True])
@parameterize("model_name", ["transformers_llama_for_causal_lm"])
@parameterize("size_per_shard", [32])
@parameterize("test_config", TEST_CONFIGS)
@parameterize("use_async", [False, True])
@clear_cache_before_run()
def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict):
def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict, use_async: bool):
(model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next(
iter(model_zoo.get_sub_registry(model_name).values())
)
@ -85,8 +86,16 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf
with shared_tempdir() as tempdir:
model_ckpt_path = f"{tempdir}/model"
optimizer_ckpt_path = f"{tempdir}/optimizer"
booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
if not shard and use_async:
model_ckpt_path = f"{model_ckpt_path}.safetensors"
optimizer_ckpt_path = f"{optimizer_ckpt_path}.safetensors"
booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async)
booster.save_optimizer(
optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async
)
booster.checkpoint_io._sync_d2h()
booster.checkpoint_io._sync_io()
dist.barrier()
new_model = model_fn().cuda()

View File

@ -12,14 +12,15 @@ from colossalai.interface import OptimizerWrapper
from colossalai.testing import check_state_dict_equal, parameterize, rerun_if_address_is_in_use, spawn
@parameterize("shard", [True, False])
@parameterize("shard", [False, True])
@parameterize("size_per_shard", [16, 128])
def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int):
@parameterize("use_async", [False, True])
def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int, use_async: bool):
plugin = TorchDDPPlugin()
booster = Booster(plugin=plugin)
model = resnet18()
criterion = lambda x: x.mean()
optimizer = SGD((model.parameters()), lr=0.001)
optimizer = SGD((model.parameters()), lr=0.001, momentum=0.5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion, lr_scheduler=scheduler)
@ -39,9 +40,18 @@ def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int):
model_ckpt_path = f"{tempdir}/model"
optimizer_ckpt_path = f"{tempdir}/optimizer"
lr_scheduler_ckpt_path = f"{tempdir}/lr_scheduler"
booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
if not shard and use_async:
model_ckpt_path = f"{model_ckpt_path}.safetensors"
optimizer_ckpt_path = f"{optimizer_ckpt_path}.safetensors"
booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async)
booster.save_optimizer(
optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async
)
booster.save_lr_scheduler(scheduler, lr_scheduler_ckpt_path)
booster.checkpoint_io._sync_d2h()
booster.checkpoint_io._sync_io()
dist.barrier()
new_model = resnet18()

View File

@ -12,7 +12,7 @@ if version.parse(torch.__version__) >= version.parse("1.12.0"):
from colossalai.booster.plugin import TorchFSDPPlugin
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
def compare_nested_dict(dict1, dict2):
@ -43,7 +43,8 @@ def compare_nested_dict(dict1, dict2):
return True
def check_torch_fsdp_ckpt():
@parameterize("use_async", [False, True])
def check_torch_fsdp_ckpt(use_async: bool):
model = resnet18()
plugin = TorchFSDPPlugin()
booster = Booster(plugin=plugin)
@ -65,10 +66,17 @@ def check_torch_fsdp_ckpt():
model_ckpt_path = f"{tempdir}/model"
optim_ckpt_path = f"{tempdir}/optimizer"
if use_async:
model_ckpt_path = f"{model_ckpt_path}.safetensors"
optim_ckpt_path = f"{optim_ckpt_path}.safetensors"
run_model()
booster.save_model(fsdp_model, model_ckpt_path, shard=False)
booster.save_optimizer(optimizer, optim_ckpt_path, shard=False)
booster.save_model(fsdp_model, model_ckpt_path, shard=False, use_async=use_async)
booster.save_optimizer(optimizer, optim_ckpt_path, shard=False, use_async=use_async)
booster.checkpoint_io._sync_d2h()
booster.checkpoint_io._sync_io()
full_msd = fsdp_model.state_dict()
# full_osd = FSDP.full_optim_state_dict(fsdp_model, optimizer)
@ -106,8 +114,11 @@ def check_torch_fsdp_ckpt():
run_model()
booster.save_model(fsdp_model, model_ckpt_path, shard=True)
booster.save_optimizer(optimizer, optim_ckpt_path, shard=True)
booster.save_model(fsdp_model, model_ckpt_path, shard=True, use_async=use_async)
booster.save_optimizer(optimizer, optim_ckpt_path, shard=True, use_async=use_async)
booster.checkpoint_io._sync_d2h()
booster.checkpoint_io._sync_io()
full_msd = fsdp_model.unwrap().state_dict()
full_osd = FSDP.full_optim_state_dict(optimizer.unwrap_model().unwrap(), optim=optimizer)