
[fsdp] impl save/load shard model/optimizer (#5357)

pull/5388/head
QinLuo 9 months ago committed by GitHub
commit bf34c6fef6
  1. colossalai/booster/plugin/torch_fsdp_plugin.py (153)
  2. tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py (38)

colossalai/booster/plugin/torch_fsdp_plugin.py

@@ -1,3 +1,5 @@
+import logging
+import os
 import warnings
 from pathlib import Path
 from typing import Callable, Iterable, Iterator, List, Optional, Tuple
@@ -25,7 +27,7 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
-from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO, utils
+from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO, utils, CheckpointIndexFile
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
@@ -74,17 +76,54 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
     def save_sharded_model(
         self,
-        model: nn.Module,
-        checkpoint: str,
-        gather_dtensor: bool,
-        prefix: Optional[str],
-        size_per_shard: int,
-        use_safetensors: bool,
+        model: ModelWrapper,
+        checkpoint_path: str,
+        gather_dtensor: bool = True,
+        prefix: Optional[str] = None,
+        size_per_shard: int = 1024,
+        use_safetensors: bool = False,
     ):
         """
         Save model to checkpoint but only on master process.
         """
-        raise NotImplementedError("Sharded model checkpoint is not supported yet.")
+        assert isinstance(model, TorchFSDPModel), "Please boost the model before saving!"
+
+        if os.path.isfile(checkpoint_path):
+            logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
+            return
+
+        Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
+        with FSDP.state_dict_type(
+            model.unwrap(),
+            StateDictType.FULL_STATE_DICT,
+            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
+        ):
+            state_dict = model.unwrap().state_dict()
+
+        state_dict_shard = utils.shard_model_checkpoint(state_dict, max_shard_size=size_per_shard)
+        weights_name, save_index_file = utils.get_model_base_filenames(prefix, use_safetensors)
+        index_file = CheckpointIndexFile(checkpoint_path)
+
+        # In general cases, is_master is set to True to get the right behavior.
+        total_size = utils.save_state_dict_shards(
+            sharded_state_dict=state_dict_shard,
+            checkpoint=checkpoint_path,
+            index_file=index_file,
+            base_filename=weights_name,
+            is_master=self.coordinator.is_master(),
+            use_safetensors=use_safetensors,
+        )
+
+        # only save the index file on the master rank
+        if self.coordinator.is_master():
+            index_file.append_meta_data("total_size", total_size)
+            index_file.write_index_file(save_index_file)
+            utils.save_config_file(model.unwrap(), checkpoint_path)
+            logging.info(
+                f"The model is split into checkpoint shards. "
+                f"You can find where each parameters has been saved in the "
+                f"index located at {save_index_file}."
+            )
 
     def load_sharded_model(
         self,
@@ -97,7 +136,24 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         """
         Load model to checkpoint but only on master process.
         """
-        raise NotImplementedError("Sharded model checkpoint is not supported yet.")
+        assert isinstance(model, TorchFSDPModel), "Please boost the model before loading!"
+
+        use_safetensors = False
+        if "safetensors" in checkpoint_index_file.name:
+            use_safetensors = True
+
+        if use_safetensors and not utils.is_safetensors_available():
+            raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")
+
+        # read checkpoint index file
+        ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
+        checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
+
+        fsdp_state_dict = {}
+        for shard_file in checkpoint_files:
+            fsdp_state_dict.update(utils.load_shard_state_dict(Path(shard_file), use_safetensors))
+
+        with FSDP.state_dict_type(model.unwrap(), StateDictType.FULL_STATE_DICT):
+            model.unwrap().load_state_dict(fsdp_state_dict, strict=False)
 
     def save_sharded_optimizer(
         self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str, size_per_shard: int
@@ -105,13 +161,86 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         """
         Save optimizer to checkpoint but only on master process.
         """
-        raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
+        assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before saving!"
+
+        if os.path.isfile(checkpoint):
+            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+            return
+
+        Path(checkpoint).mkdir(parents=True, exist_ok=True)
+
+        with FSDP.state_dict_type(
+            optimizer.unwrap_model().unwrap(),
+            StateDictType.FULL_STATE_DICT,
+            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
+        ):
+            fsdp_optim_state = FSDP.full_optim_state_dict(
+                optimizer.unwrap_model().unwrap(), optim=optimizer, rank0_only=True
+            )
+
+        if self.coordinator.is_master():
+            # Preparing file paths and index file.
+            states_name, save_index_file, param_group_file = utils.get_optimizer_base_filenames(prefix)
+            index_file = CheckpointIndexFile(checkpoint)
+
+            index_file.append_meta_data("param_groups", param_group_file)
+            group_file_path = os.path.join(checkpoint, param_group_file)
+            utils.save_param_groups(fsdp_optim_state, group_file_path)
+
+            sharded_state = utils.shard_optimizer_checkpoint(fsdp_optim_state, max_shard_size=size_per_shard)
+
+            # Save shards of optimizer states.
+            # In general cases, is_master is set to True to get the right behavior.
+            total_size = utils.save_state_dict_shards(
+                sharded_state_dict=sharded_state,
+                checkpoint=checkpoint,
+                index_file=index_file,
+                base_filename=states_name,
+                is_master=self.coordinator.is_master(),
+                use_safetensors=False,
+            )
+
+            index_file.append_meta_data("total_size", total_size)
+            index_file.write_index_file(save_index_file)
+            logging.info(
+                f"The optimizer is going to be split to checkpoint shards. "
+                f"You can find where each parameters has been saved in the "
+                f"index located at {save_index_file}."
+            )
 
     def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, size_per_shard: int):
         """
         Load optimizer to checkpoint but only on master process.
         """
-        raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
+        assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before saving!"
+
+        ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
+
+        # Load param_groups
+        param_group_path = ckpt_index_file.get_param_group_filename()
+        if param_group_path is None:
+            raise RuntimeError(
+                f"Invalid index file path {index_file_path} for an optimizer. "
+                "Looking param group file under current directory."
+            )
+        saved_param_groups = torch.load(param_group_path)
+
+        # Load param
+        fsdp_optim_state = {}
+        checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
+        for shard_file in checkpoint_files:
+            state_dict_shard = utils.load_shard_state_dict(Path(shard_file), use_safetensors=False)
+            fsdp_optim_state.update(state_dict_shard)
+
+        fsdp_optim_dict = dict(state=fsdp_optim_state, param_groups=saved_param_groups)
+
+        with FSDP.state_dict_type(optimizer.unwrap_model().unwrap(), StateDictType.FULL_STATE_DICT):
+            fsdp_state = FSDP.optim_state_dict_to_load(
+                model=optimizer.unwrap_model().unwrap(), optim=optimizer, optim_state_dict=fsdp_optim_dict
+            )
+            optimizer.load_state_dict(fsdp_state)
 
     def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
         """
@@ -190,7 +319,7 @@ class TorchFSDPPlugin(DPPluginBase):
             raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
 
     def support_no_sync(self) -> bool:
-        False
+        return False
 
     def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         raise NotImplementedError("Torch fsdp no_sync func not supported yet.")

tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py

@@ -10,6 +10,7 @@ from colossalai.booster import Booster
 if version.parse(torch.__version__) >= version.parse("1.12.0"):
     from colossalai.booster.plugin import TorchFSDPPlugin
+    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 
 from colossalai.testing import rerun_if_address_is_in_use, spawn
@@ -99,6 +100,43 @@ def check_torch_fsdp_ckpt():
         outputs_sec = fsdp_model(inputs)
         assert criterion(outputs_sec) == criterion(outputs)
 
+    with shared_tempdir() as tempdir:
+        model_ckpt_path = f"{tempdir}/model"
+        optim_ckpt_path = f"{tempdir}/optimizer"
+
+        run_model()
+
+        booster.save_model(fsdp_model, model_ckpt_path, shard=True)
+        booster.save_optimizer(optimizer, optim_ckpt_path, shard=True)
+
+        full_msd = fsdp_model.unwrap().state_dict()
+        full_osd = FSDP.full_optim_state_dict(optimizer.unwrap_model().unwrap(), optim=optimizer)
+
+        import copy
+
+        sharded_osd = copy.deepcopy(full_osd)
+
+        run_model()
+
+        full_msd_updated = fsdp_model.unwrap().state_dict()
+        full_osd_updated = FSDP.full_optim_state_dict(optimizer.unwrap_model().unwrap(), optim=optimizer)
+
+        # cost much time led to timeout
+        # assert not compare_nested_dict(full_osd_updated, sharded_osd)
+        # assert not compare_nested_dict(full_msd_updated, full_msd)
+        outputs_first = fsdp_model(inputs)
+        assert criterion(outputs_first) != criterion(outputs)
+
+        booster.load_model(fsdp_model, model_ckpt_path)
+        booster.load_optimizer(optimizer, optim_ckpt_path)
+
+        full_msd_restore = fsdp_model.unwrap().state_dict()
+        sharded_osd_restore = FSDP.full_optim_state_dict(optimizer.unwrap_model().unwrap(), optim=optimizer)
+
+        assert compare_nested_dict(sharded_osd, sharded_osd_restore)
+        assert compare_nested_dict(full_msd_restore, full_msd)
+        outputs_sec = fsdp_model(inputs)
+        assert criterion(outputs_sec) == criterion(outputs)
 
 def run_dist(rank, world_size, port):
     # init dist env
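The assertions above use a compare_nested_dict helper defined elsewhere in this test file and not shown in this diff. For orientation only, a hypothetical sketch of such a recursive comparison (name taken from the test, exact semantics assumed):

import torch

def compare_nested_dict(dict1, dict2) -> bool:
    """Recursively compare two possibly nested containers of tensors/scalars (sketch, not the real helper)."""
    if isinstance(dict1, dict) and isinstance(dict2, dict):
        if dict1.keys() != dict2.keys():
            return False
        return all(compare_nested_dict(dict1[k], dict2[k]) for k in dict1)
    if isinstance(dict1, (list, tuple)) and isinstance(dict2, (list, tuple)):
        return len(dict1) == len(dict2) and all(
            compare_nested_dict(a, b) for a, b in zip(dict1, dict2)
        )
    if isinstance(dict1, torch.Tensor) and isinstance(dict2, torch.Tensor):
        return torch.equal(dict1, dict2)
    return dict1 == dict2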
