mirror of https://github.com/hpcaitech/ColossalAI
[fsdp] impl save/load shard model/optimizer (#5357)
parent
d882d18c65
commit
bf34c6fef6
|
@ -1,3 +1,5 @@
|
|||
import logging
|
||||
import os
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Callable, Iterable, Iterator, List, Optional, Tuple
|
||||
|
@ -25,7 +27,7 @@ from torch.optim import Optimizer
|
|||
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO, utils
|
||||
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO, utils, CheckpointIndexFile
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||
|
||||
|
@ -74,17 +76,54 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
|
|||
|
||||
def save_sharded_model(
|
||||
self,
|
||||
model: nn.Module,
|
||||
checkpoint: str,
|
||||
gather_dtensor: bool,
|
||||
prefix: Optional[str],
|
||||
size_per_shard: int,
|
||||
use_safetensors: bool,
|
||||
model: ModelWrapper,
|
||||
checkpoint_path: str,
|
||||
gather_dtensor: bool = True,
|
||||
prefix: Optional[str] = None,
|
||||
size_per_shard: int = 1024,
|
||||
use_safetensors: bool = False,
|
||||
):
|
||||
"""
|
||||
Save model to checkpoint but only on master process.
|
||||
"""
|
||||
raise NotImplementedError("Sharded model checkpoint is not supported yet.")
|
||||
assert isinstance(model, TorchFSDPModel), "Please boost the model before saving!"
|
||||
if os.path.isfile(checkpoint_path):
|
||||
logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
|
||||
return
|
||||
|
||||
Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
|
||||
with FSDP.state_dict_type(
|
||||
model.unwrap(),
|
||||
StateDictType.FULL_STATE_DICT,
|
||||
FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
|
||||
):
|
||||
state_dict = model.unwrap().state_dict()
|
||||
|
||||
state_dict_shard = utils.shard_model_checkpoint(state_dict, max_shard_size=size_per_shard)
|
||||
|
||||
weights_name, save_index_file = utils.get_model_base_filenames(prefix, use_safetensors)
|
||||
index_file = CheckpointIndexFile(checkpoint_path)
|
||||
|
||||
# In general cases, is_master is set to True to get the right behavior.
|
||||
total_size = utils.save_state_dict_shards(
|
||||
sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint_path,
|
||||
index_file=index_file,
|
||||
base_filename=weights_name,
|
||||
is_master=self.coordinator.is_master(),
|
||||
use_safetensors=use_safetensors,
|
||||
)
|
||||
|
||||
# only save the index file on the master rank
|
||||
if self.coordinator.is_master():
|
||||
index_file.append_meta_data("total_size", total_size)
|
||||
index_file.write_index_file(save_index_file)
|
||||
utils.save_config_file(model.unwrap(), checkpoint_path)
|
||||
logging.info(
|
||||
f"The model is split into checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {save_index_file}."
|
||||
)
|
||||
|
||||
def load_sharded_model(
|
||||
self,
|
||||
|
@ -97,7 +136,24 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
|
|||
"""
|
||||
Load model to checkpoint but only on master process.
|
||||
"""
|
||||
raise NotImplementedError("Sharded model checkpoint is not supported yet.")
|
||||
assert isinstance(model, TorchFSDPModel), "Please boost the model before loading!"
|
||||
use_safetensors = False
|
||||
if "safetensors" in checkpoint_index_file.name:
|
||||
use_safetensors = True
|
||||
|
||||
if use_safetensors and not utils.is_safetensors_available():
|
||||
raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")
|
||||
|
||||
# read checkpoint index file
|
||||
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
|
||||
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
|
||||
|
||||
fsdp_state_dict = {}
|
||||
for shard_file in checkpoint_files:
|
||||
fsdp_state_dict.update(utils.load_shard_state_dict(Path(shard_file), use_safetensors))
|
||||
|
||||
with FSDP.state_dict_type(model.unwrap(), StateDictType.FULL_STATE_DICT):
|
||||
model.unwrap().load_state_dict(fsdp_state_dict, strict=False)
|
||||
|
||||
def save_sharded_optimizer(
|
||||
self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str, size_per_shard: int
|
||||
|
@ -105,13 +161,86 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
|
|||
"""
|
||||
Save optimizer to checkpoint but only on master process.
|
||||
"""
|
||||
raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
|
||||
assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before saving!"
|
||||
|
||||
if os.path.isfile(checkpoint):
|
||||
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
|
||||
return
|
||||
|
||||
Path(checkpoint).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with FSDP.state_dict_type(
|
||||
optimizer.unwrap_model().unwrap(),
|
||||
StateDictType.FULL_STATE_DICT,
|
||||
FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
|
||||
):
|
||||
fsdp_optim_state = FSDP.full_optim_state_dict(
|
||||
optimizer.unwrap_model().unwrap(), optim=optimizer, rank0_only=True
|
||||
)
|
||||
|
||||
if self.coordinator.is_master():
|
||||
# Preparing file paths and index file.
|
||||
states_name, save_index_file, param_group_file = utils.get_optimizer_base_filenames(prefix)
|
||||
index_file = CheckpointIndexFile(checkpoint)
|
||||
|
||||
index_file.append_meta_data("param_groups", param_group_file)
|
||||
group_file_path = os.path.join(checkpoint, param_group_file)
|
||||
utils.save_param_groups(fsdp_optim_state, group_file_path)
|
||||
|
||||
sharded_state = utils.shard_optimizer_checkpoint(fsdp_optim_state, max_shard_size=size_per_shard)
|
||||
|
||||
# Save shards of optimizer states.
|
||||
# In general cases, is_master is set to True to get the right behavior.
|
||||
total_size = utils.save_state_dict_shards(
|
||||
sharded_state_dict=sharded_state,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=states_name,
|
||||
is_master=self.coordinator.is_master(),
|
||||
use_safetensors=False,
|
||||
)
|
||||
|
||||
index_file.append_meta_data("total_size", total_size)
|
||||
index_file.write_index_file(save_index_file)
|
||||
logging.info(
|
||||
f"The optimizer is going to be split to checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {save_index_file}."
|
||||
)
|
||||
|
||||
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, size_per_shard: int):
|
||||
"""
|
||||
Load optimizer to checkpoint but only on master process.
|
||||
"""
|
||||
raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
|
||||
assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before saving!"
|
||||
|
||||
ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
|
||||
|
||||
# Load param_groups
|
||||
param_group_path = ckpt_index_file.get_param_group_filename()
|
||||
if param_group_path is None:
|
||||
raise RuntimeError(
|
||||
f"Invalid index file path {index_file_path} for an optimizer. "
|
||||
"Looking param group file under current directory."
|
||||
)
|
||||
|
||||
saved_param_groups = torch.load(param_group_path)
|
||||
|
||||
# Load param
|
||||
fsdp_optim_state = {}
|
||||
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
|
||||
for shard_file in checkpoint_files:
|
||||
state_dict_shard = utils.load_shard_state_dict(Path(shard_file), use_safetensors=False)
|
||||
fsdp_optim_state.update(state_dict_shard)
|
||||
|
||||
fsdp_optim_dict = dict(state=fsdp_optim_state, param_groups=saved_param_groups)
|
||||
|
||||
with FSDP.state_dict_type(optimizer.unwrap_model().unwrap(), StateDictType.FULL_STATE_DICT):
|
||||
fsdp_state = FSDP.optim_state_dict_to_load(
|
||||
model=optimizer.unwrap_model().unwrap(), optim=optimizer, optim_state_dict=fsdp_optim_dict
|
||||
)
|
||||
optimizer.load_state_dict(fsdp_state)
|
||||
|
||||
|
||||
def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
|
||||
"""
|
||||
|
@ -190,7 +319,7 @@ class TorchFSDPPlugin(DPPluginBase):
|
|||
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
|
||||
|
||||
def support_no_sync(self) -> bool:
|
||||
False
|
||||
return False
|
||||
|
||||
def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
|
||||
raise NotImplementedError("Torch fsdp no_sync func not supported yet.")
|
||||
|
|
|
@ -10,6 +10,7 @@ from colossalai.booster import Booster
|
|||
|
||||
if version.parse(torch.__version__) >= version.parse("1.12.0"):
|
||||
from colossalai.booster.plugin import TorchFSDPPlugin
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
|
||||
|
@ -99,6 +100,43 @@ def check_torch_fsdp_ckpt():
|
|||
outputs_sec = fsdp_model(inputs)
|
||||
assert criterion(outputs_sec) == criterion(outputs)
|
||||
|
||||
with shared_tempdir() as tempdir:
|
||||
model_ckpt_path = f"{tempdir}/model"
|
||||
optim_ckpt_path = f"{tempdir}/optimizer"
|
||||
|
||||
run_model()
|
||||
|
||||
booster.save_model(fsdp_model, model_ckpt_path, shard=True)
|
||||
booster.save_optimizer(optimizer, optim_ckpt_path, shard=True)
|
||||
|
||||
full_msd = fsdp_model.unwrap().state_dict()
|
||||
full_osd = FSDP.full_optim_state_dict(optimizer.unwrap_model().unwrap(), optim=optimizer)
|
||||
|
||||
import copy
|
||||
sharded_osd = copy.deepcopy(full_osd)
|
||||
|
||||
run_model()
|
||||
|
||||
full_msd_updated = fsdp_model.unwrap().state_dict()
|
||||
full_osd_updated = FSDP.full_optim_state_dict(optimizer.unwrap_model().unwrap(), optim=optimizer)
|
||||
|
||||
# cost much time led to timeout
|
||||
# assert not compare_nested_dict(full_osd_updated, sharded_osd)
|
||||
# assert not compare_nested_dict(full_msd_updated, full_msd)
|
||||
outputs_first = fsdp_model(inputs)
|
||||
assert criterion(outputs_first) != criterion(outputs)
|
||||
|
||||
booster.load_model(fsdp_model, model_ckpt_path)
|
||||
booster.load_optimizer(optimizer, optim_ckpt_path)
|
||||
|
||||
full_msd_restore = fsdp_model.unwrap().state_dict()
|
||||
sharded_osd_restore = FSDP.full_optim_state_dict(optimizer.unwrap_model().unwrap(), optim=optimizer)
|
||||
|
||||
assert compare_nested_dict(sharded_osd, sharded_osd_restore)
|
||||
assert compare_nested_dict(full_msd_restore, full_msd)
|
||||
outputs_sec = fsdp_model(inputs)
|
||||
assert criterion(outputs_sec) == criterion(outputs)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
# init dist env
|
||||
|
|
Loading…
Reference in New Issue