[shardformer] support sharded checkpoint IO for models of HybridParallelPlugin (#4506)

* add APIs

* implement save_sharded_model

* add test for hybrid checkpointio

* implement naive loading for sharded model

* implement efficient sharded model loading

* open a new file for hybrid checkpoint_io

* small fix

* fix circular importing

* fix docstring

* arrange arguments and apis

* small fix
Baizhou Zhang 2023-08-25 22:04:57 +08:00 committed by GitHub
parent de8a65babc
commit 44eab2b27f
7 changed files with 497 additions and 40 deletions

File: colossalai/booster/plugin/hybrid_parallel_plugin.py

@@ -16,7 +16,7 @@ from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler

 from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer
-from colossalai.checkpoint_io import CheckpointIO
+from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
@@ -292,6 +292,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.schedule = OneForwardOneBackwardSchedule(num_microbatches, self.stage_manager)
         self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
         self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
+        self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
         self.shard_config = ShardConfig(tensor_parallel_process_group=self.tp_group,
                                         pipeline_stage_manager=self.stage_manager,
                                         enable_tensor_parallelism=self.tp_size > 1,
@@ -460,7 +461,7 @@ class HybridParallelPlugin(PipelinePluginBase):
                          **_kwargs)

     def get_checkpoint_io(self) -> CheckpointIO:
-        return None
+        return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group)

     def no_sync(self, model: Module) -> Iterator[None]:
         raise NotImplementedError
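With get_checkpoint_io now returning a real checkpoint IO object, sharded saving works end to end through the booster. A minimal sketch under assumptions not shown in this diff (the Linear model and optimizer are placeholders, and colossalai.launch is expected to have initialized a 2-rank distributed environment beforehand):

import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Placeholder model/optimizer; any module supported by shardformer works here.
model = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

plugin = HybridParallelPlugin(tp_size=2, pp_size=1)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)

# save_model routes through the HybridParallelCheckpointIO instance
# returned by plugin.get_checkpoint_io().
booster.save_model(model, "./ckpt", shard=True, size_per_shard=1024)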

File: colossalai/checkpoint_io/__init__.py

@@ -1,5 +1,6 @@
 from .checkpoint_io_base import CheckpointIO
 from .general_checkpoint_io import GeneralCheckpointIO
+from .hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO
 from .index_file import CheckpointIndexFile

-__all__ = ['CheckpointIO', 'CheckpointIndexFile', 'GeneralCheckpointIO']
+__all__ = ['CheckpointIO', 'CheckpointIndexFile', 'GeneralCheckpointIO', 'HybridParallelCheckpointIO']

File: colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py (new file)

@@ -0,0 +1,316 @@
import copy
import gc
import logging
import os
from pathlib import Path
from shutil import rmtree
from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler

from colossalai.cluster import DistCoordinator, ProcessGroupMesh
from colossalai.tensor.d_tensor import (
    is_customized_distributed_tensor,
    is_distributed_tensor,
    to_global,
    to_global_for_customized_distributed_tensor,
)

from .general_checkpoint_io import GeneralCheckpointIO
from .index_file import CheckpointIndexFile
from .utils import (
    StateDictSharder,
    calculate_tensor_size,
    gather_distributed_param,
    get_model_base_filenames,
    get_optimizer_base_filenames,
    get_shard_filename,
    is_safetensors_available,
    load_shard_state_dict,
    load_state_dict_into_model,
    save_param_groups,
    save_state_dict,
    save_state_dict_shards,
)

try:
    from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
except ImportError:
    _EXTRA_STATE_KEY_SUFFIX = '_extra_state'

class HybridParallelCheckpointIO(GeneralCheckpointIO):
    """
    CheckpointIO for Hybrid Parallel Training.

    Args:
        dp_group (ProcessGroup): Process group along the data parallel dimension.
        pp_group (ProcessGroup): Process group along the pipeline parallel dimension.
        tp_group (ProcessGroup): Process group along the tensor parallel dimension.
    """

    def __init__(self, dp_group: ProcessGroup, pp_group: ProcessGroup, tp_group: ProcessGroup) -> None:
        super().__init__()
        self.dp_group = dp_group
        self.pp_group = pp_group
        self.tp_group = tp_group
        self.dp_rank = dist.get_rank(self.dp_group)
        self.tp_rank = dist.get_rank(self.tp_group)
        self.pp_rank = dist.get_rank(self.pp_group)
        self.dp_size = dist.get_world_size(dp_group)
        self.pp_size = dist.get_world_size(pp_group)
        self.tp_size = dist.get_world_size(tp_group)
        # save_lr_scheduler below relies on a coordinator to restrict writes to the master process.
        self.coordinator = DistCoordinator()

    @staticmethod
    def _model_sharder(model: nn.Module,
                       prefix: str = '',
                       keep_vars: bool = False,
                       size_per_shard: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
        # An internal method that breaks the state_dict of the model into shards within a limited size.
        state_dict_sharder = StateDictSharder(size_per_shard)

        # Save parameters.
        for name, param in model.named_parameters():
            if param is None:
                continue
            # Gather tensor pieces when using tensor parallelism.
            param_ = gather_distributed_param(param, keep_vars=False)
            block, block_size = state_dict_sharder.append(prefix + name, param_)
            if block is not None:
                yield block, block_size

        # Save buffers.
        for name, buf in model.named_buffers():
            if buf is not None and name not in model._non_persistent_buffers_set:
                buffer = buf if keep_vars else buf.detach()
                block, block_size = state_dict_sharder.append(prefix + name, buffer)
                if block is not None:
                    yield block, block_size

        # Save extra states.
        extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
        if getattr(model.__class__, "get_extra_state",
                   torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
            extra_state = model.get_extra_state()
            block, block_size = state_dict_sharder.append(extra_state_key, extra_state)
            if block is not None:
                yield block, block_size

        # Return the last block in the sharder.
        yield state_dict_sharder.current_block, state_dict_sharder.current_block_size

    @staticmethod
    def _optimizer_sharder(optimizer: Optimizer, size_per_shard: int = 1024):
        # An internal method that breaks the state_dict of the optimizer into shards within a limited size.
        # TODO (Baizhou): Implement sharding feature of optimizer.
        pass

    def save_sharded_model(self,
                           model: nn.Module,
                           checkpoint: str,
                           gather_dtensor: bool = True,
                           prefix: Optional[str] = None,
                           size_per_shard: int = 1024,
                           use_safetensors: bool = False) -> None:
        """
        Save sharded model checkpoint under the given checkpointing path.
        The following files will be created under the path:
        - An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names.
        - Multiple files that store state tensors of models.
          If pipeline parallelism is used, the filenames are in the form of "pytorch_model.<prefix>-stage-000XX-shard-000XX.bin".
          If pipeline parallelism is not used, the filenames are in the form of "pytorch_model.<prefix>-000XX.bin".

        Args:
            model (nn.Module): Model on local device to be saved.
            checkpoint (str): Checkpointing path, which should be a directory path.
            gather_dtensor (bool, optional): Whether to gather the distributed tensors; currently not used. Defaults to True.
            prefix (str, optional): Prefix of the files to save. Defaults to None.
            size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
            use_safetensors (bool, optional): Whether to use safetensors. Defaults to False.
        """

        if os.path.isfile(checkpoint):
            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
            return

        Path(checkpoint).mkdir(parents=True, exist_ok=True)

        # Devices along the same dp_group share the same copy of the model.
        # So only let the device with dp_rank == 0 save the model.
        if self.dp_rank != 0:
            return

        # Then collect the sharded parameters & buffers along tp_group.
        # Only devices with tp_rank == 0 are responsible for model saving.
        state_dict_shard = HybridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard)
        weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
        index_file = CheckpointIndexFile(checkpoint)
        control_saving = (self.tp_rank == 0)

        if self.pp_size == 1:
            # When pipeline parallelism is not used, save the model shards as in general checkpoint IO.
            total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
                                                checkpoint=checkpoint,
                                                index_file=index_file,
                                                base_filename=weights_name,
                                                is_master=control_saving,
                                                use_safetensors=use_safetensors)
            if control_saving:
                index_file.append_meta_data("total_size", total_size)
                index_file.write_index_file(save_index_file)
                logging.info(f"The model is split into checkpoint shards. "
                             f"You can find where each parameter has been saved in the "
                             f"index located at {save_index_file}.")
        else:
            # When pipeline parallelism is used, each stage produces its own shard files and index files.
            # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/.
            # After all the state_dicts have been saved, the master rank integrates all the index files
            # into one final index file and deletes the tmp folder.
            final_index_file_path = copy.deepcopy(save_index_file)
            tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files")
            Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)

            # Manage filenames of sharded weights and index files for each pipeline stage.
            weights_name = weights_name.replace(".bin", f"-stage-{self.pp_rank:05d}-shard.bin")
            weights_name = weights_name.replace(".safetensors", f"-stage-{self.pp_rank:05d}-shard.safetensors")
            save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank:05d}.json")
            save_index_file = os.path.join("tmp_index_files", save_index_file)

            total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
                                                checkpoint=checkpoint,
                                                index_file=index_file,
                                                base_filename=weights_name,
                                                is_master=control_saving,
                                                use_safetensors=use_safetensors)
            if control_saving:
                assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0."
                index_file.append_meta_data("total_size", total_size)
                index_file.write_index_file(save_index_file)
            else:
                return

            dist.barrier(self.pp_group)

            # The global master rank integrates the index files and cleans up the temporary folder.
            if self.pp_rank == 0:
                final_index_file = CheckpointIndexFile(checkpoint)
                final_index_file.append_meta_data("total_size", 0)

                for filename in os.listdir(tmp_index_file_folder):
                    stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename))
                    final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"]
                    for weight, weight_filename in stage_index_file.weight_map.items():
                        final_index_file.append_weight_map(weight, weight_filename)

                final_index_file.write_index_file(final_index_file_path)
                rmtree(tmp_index_file_folder)
                logging.info(f"The model is split into checkpoint shards. "
                             f"You can find where each parameter has been saved in the "
                             f"index located at {final_index_file_path}.")

    def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False):
        """
        Load a sharded model with the given path to the index file of the checkpoint folder.

        Args:
            model (nn.Module): The model to be loaded.
            checkpoint_index_file (Path): Path to the index file of the checkpointing folder.
            strict (bool, optional): Whether to enforce strict name matching when loading the state_dict. Defaults to False.
                This argument should be manually set to False, since params on the same device might be stored in different files.
        """

        # Check whether the checkpoint uses safetensors.
        use_safetensors = False
        if "safetensors" in checkpoint_index_file.name:
            use_safetensors = True

        if use_safetensors and not is_safetensors_available():
            raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")

        # Read the checkpoint index file.
        ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
        ckpt_root_path = ckpt_index_file.root_path
        weight_map = ckpt_index_file.weight_map
        # Force strict=False: params on the same device may be stored across different files.
        strict = False

        # Load params & buffers into the model.
        # Keep a record of loaded files so that a file will not be repeatedly loaded.
        loaded_file = set()

        def _load(name: str):
            if name not in weight_map:
                raise ValueError(f"{name} is not stored in checkpoint, please check your checkpointing configuration!")
            filename = weight_map[name]

            # If this param/buffer has been loaded before, directly return.
            if filename in loaded_file:
                return

            file_path = os.path.join(ckpt_root_path, filename)
            state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
            missing_keys = []

            load_state_dict_into_model(model,
                                       state_dict,
                                       missing_keys=missing_keys,
                                       strict=strict,
                                       load_sub_module=True)
            del state_dict
            loaded_file.add(filename)

        # Load parameters.
        for name, _ in model.named_parameters():
            _load(name)

        # Load buffers.
        for name, buf in model.named_buffers():
            if buf is not None and name not in model._non_persistent_buffers_set:
                _load(name)

        # Load extra states.
        extra_state_key = _EXTRA_STATE_KEY_SUFFIX
        if getattr(model.__class__, "get_extra_state",
                   torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
            _load(extra_state_key)

    def save_sharded_optimizer(self,
                               optimizer: Optimizer,
                               checkpoint: str,
                               gather_dtensor: bool = True,
                               prefix: Optional[str] = None,
                               size_per_shard: int = 1024):
        # TODO (Baizhou): Implement sharded optimizer saving (see _optimizer_sharder above).
        pass

    def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
        # TODO (Baizhou): Implement sharded optimizer loading.
        pass

    def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
        # TODO(Baizhou): support this feature after implementing complete state_dict collection
        raise NotImplementedError

    def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
        # TODO(Baizhou): support this feature after implementing complete state_dict collection
        raise NotImplementedError

    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
        # TODO(Baizhou): support this feature after implementing complete state_dict collection
        raise NotImplementedError

    def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
        # TODO(Baizhou): support this feature after implementing complete state_dict collection
        raise NotImplementedError

    def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
        """
        Save the lr scheduler to checkpoint, but only on the master process.
        """
        if self.coordinator.is_master():
            super().save_lr_scheduler(lr_scheduler, checkpoint)
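For orientation, a sketch of the directory that save_sharded_model produces when pipeline parallelism is used (assuming two pipeline stages, one shard per stage, and the default weights name; the shard numbering is assumed to follow get_shard_filename):

ckpt/
    pytorch_model-stage-00000-shard-00001.bin    written by the pp_rank 0 / tp_rank 0 / dp_rank 0 process
    pytorch_model-stage-00001-shard-00001.bin    written by the pp_rank 1 / tp_rank 0 / dp_rank 0 process
    pytorch_model.bin.index.json                 per-stage indices merged by pp_rank 0

The temporary tmp_index_files/ folder holding the per-stage indices is removed once the merged index has been written.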

File: colossalai/checkpoint_io/utils.py

@@ -13,7 +13,12 @@ from torch.optim import Optimizer

 from colossalai.interface import OptimizerWrapper
 from colossalai.nn.optimizer import ColossalaiOptimizer
-from colossalai.tensor.d_tensor import is_distributed_tensor
+from colossalai.tensor.d_tensor import (
+    is_customized_distributed_tensor,
+    is_distributed_tensor,
+    to_global,
+    to_global_for_customized_distributed_tensor,
+)

 SAFE_WEIGHTS_NAME = "model.safetensors"
 WEIGHTS_NAME = "pytorch_model.bin"
@@ -88,8 +93,28 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
         return False


+def gather_distributed_param(param: torch.Tensor, keep_vars: bool = False):
+    """
+    Gather the complete parameter for saving if the passed-in param is distributed.
+
+    Args:
+        param (torch.Tensor): A model parameter, which may be a distributed tensor (d_tensor).
+        keep_vars (bool, optional): Whether to keep the parameter attached to the autograd graph. Defaults to False.
+
+    Returns:
+        torch.Tensor: the complete parameter
+    """
+    param_ = param if keep_vars else param.detach()
+    if is_distributed_tensor(param_):
+        return to_global(param_)
+    elif is_customized_distributed_tensor(param_):
+        return to_global_for_customized_distributed_tensor(param_)
+    else:
+        return param_
+
+
 # ======================================
-# Helper functions for saving shard file
+# Helper classes and functions for saving shard file
 # ======================================


 def unwrap_optimizer(optimizer: OptimizerWrapper):
     '''
@@ -104,6 +129,31 @@ def unwrap_optimizer(optimizer: OptimizerWrapper):
     return unwrapped_optim


+class StateDictSharder:
+
+    def __init__(self, size_per_shard: int) -> None:
+        self.max_shard_size = size_per_shard
+        self.current_block = OrderedDict()
+        self.current_block_size = 0
+
+    def append(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]:
+        tensor_size = calculate_tensor_size(tensor)
+        ret_block = None
+        ret_block_size = 0
+
+        # Before returning the current block and creating a new one,
+        # make sure that the current block is not empty.
+        if self.current_block_size + tensor_size > self.max_shard_size and self.current_block_size > 0:
+            ret_block = self.current_block
+            ret_block_size = self.current_block_size
+            self.current_block = OrderedDict()
+            self.current_block_size = 0
+        self.current_block[name] = tensor
+        self.current_block_size += tensor_size
+        return ret_block, ret_block_size
+
+
 def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
                            checkpoint: str,
                            index_file: "CheckpointIndexFile",
@@ -126,9 +176,10 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
     total_size = 0
     for idx, shard_pair in enumerate(sharded_state_dict):
-        if not is_master:
-            continue
         shard, current_size = shard_pair
+        if not is_master:
+            del shard
+            continue
         shard_file = get_shard_filename(base_filename, idx)
         total_size = total_size + current_size
         for key in shard.keys():
@@ -137,6 +188,7 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]],

         # Only save on master rank.
         save_state_dict(shard, checkpoint_file_path, use_safetensors=use_safetensors)
+        del shard

     return total_size
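A minimal sketch of how the new StateDictSharder is driven, mirroring _model_sharder in hybrid_parallel_checkpoint_io.py (the 1 MB-per-tensor sizing assumes calculate_tensor_size reports sizes in MB, matching the size_per_shard unit documented above):

import torch

from colossalai.checkpoint_io.utils import StateDictSharder

# size_per_shard acts as a soft cap: a block is emitted once the next tensor
# would push it past the cap, so a single oversized tensor still gets its own block.
sharder = StateDictSharder(size_per_shard=4)

blocks = []
for i in range(10):
    # Each fp32 tensor below is 1 MB (262144 elements * 4 bytes).
    block, block_size = sharder.append(f"layer{i}.weight", torch.randn(256, 1024))
    if block is not None:    # a full block was just closed off
        blocks.append((block, block_size))

# The final, partially filled block must be flushed by the caller,
# exactly as _model_sharder yields it at the end.
blocks.append((sharder.current_block, sharder.current_block_size))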

File: colossalai/shardformer/layer/parallel_module.py

@@ -10,6 +10,7 @@ import torch.nn as nn
 from torch.distributed import ProcessGroup
 from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, Module

+from colossalai.checkpoint_io.utils import gather_distributed_param
 from colossalai.tensor.d_tensor import (
     distribute_tensor,
     distribute_tensor_with_customization,
@@ -56,13 +57,7 @@ class ParallelModule(nn.Module, ABC):
         """
         for name, param in self._parameters.items():
             if param is not None:
-                param_ = param if keep_vars else param.detach()
-                if is_distributed_tensor(param_):
-                    destination[prefix + name] = to_global(param_)
-                elif is_customized_distributed_tensor(param_):
-                    destination[prefix + name] = to_global_for_customized_distributed_tensor(param_)
-                else:
-                    destination[prefix + name] = param_
+                destination[prefix + name] = gather_distributed_param(param, keep_vars=keep_vars)

         for name, buf in self._buffers.items():
             if buf is not None and name not in self._non_persistent_buffers_set:
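The rewrite above makes ParallelModule._save_to_state_dict share one gathering code path with HybridParallelCheckpointIO._model_sharder. A small sanity sketch of gather_distributed_param's fallback branch (for an ordinary, non-distributed parameter it simply detaches; gathering only kicks in for d_tensors):

import torch

from colossalai.checkpoint_io.utils import gather_distributed_param

param = torch.nn.Parameter(torch.randn(4, 4))
saved = gather_distributed_param(param, keep_vars=False)

assert saved.shape == param.shape
assert not saved.requires_grad    # detached copy, ready for serialization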

File: colossalai/zero/gemini/gemini_ddp.py

@@ -8,7 +8,7 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn

-from colossalai.checkpoint_io.utils import calculate_tensor_size
+from colossalai.checkpoint_io.utils import StateDictSharder
 from colossalai.lazy import LazyTensor
 from colossalai.logging import get_dist_logger
 from colossalai.nn.parallel.data_parallel import ColoDDP, _cast_float, free_storage
@@ -657,7 +657,7 @@ class ZeroDDP(ColoDDP):
         Yields:
             Iterator[OrderedDict]: A generator of state dict shard
         """
-        sharder = _StateDictSharder(max_shard_size)
+        sharder = StateDictSharder(max_shard_size)

         # get the mapping between copies and fp16 parameters
         fp16_to_fp32 = dict()
@@ -705,30 +705,6 @@ class ZeroDDP(ColoDDP):
         yield sharder.current_block, sharder.current_block_size


-class _StateDictSharder:
-
-    def __init__(self, max_shard_size: int) -> None:
-        self.max_shard_size = max_shard_size
-        self.current_block = OrderedDict()
-        self.current_block_size = 0
-
-    def append(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]:
-        tensor_size = calculate_tensor_size(tensor)
-        ret_block = None
-        ret_block_size = 0
-
-        # before we return the current block and create a new block,
-        # we need to ensure that the current block is not empty
-        if self.current_block_size + tensor_size > self.max_shard_size and self.current_block_size > 0:
-            ret_block = self.current_block
-            ret_block_size = self.current_block_size
-            self.current_block = OrderedDict()
-            self.current_block_size = 0
-        self.current_block[name] = tensor
-        self.current_block_size += tensor_size
-        return ret_block, ret_block_size
-
-
 class GeminiDDP(ZeroDDP):

     def __init__(self,
File: tests/test_checkpoint_io/… (new test file for hybrid checkpoint IO)

@@ -0,0 +1,116 @@
import pytest
import torch
import torch.distributed as dist
from torch.optim import Adam
from utils import shared_tempdir

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import (
    check_state_dict_equal,
    clear_cache_before_run,
    parameterize,
    rerun_if_address_is_in_use,
    spawn,
)
from tests.kit.model_zoo import model_zoo


@clear_cache_before_run()
@parameterize('shard', [True])
@parameterize('model_name', ['transformers_gpt'])
@parameterize('size_per_shard', [32])
@parameterize('test_config', [{
    'tp_size': 2,
    'pp_size': 2,
    'num_microbatches': 4,
    'precision': 'fp32',
}, {
    'tp_size': 1,
    'pp_size': 2,
    'num_microbatches': 4,
    'precision': 'fp32',
}, {
    'tp_size': 4,
    'pp_size': 1,
    'precision': 'fp32',
}, {
    'tp_size': 2,
    'pp_size': 1,
    'precision': 'fp32',
}, {
    'tp_size': 2,
    'pp_size': 1,
    'zero_stage': 2,
    'precision': 'fp16',
    'initial_scale': 1
}])
def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict):
    (model_fn, data_gen_fn, output_transform_fn, loss_fn,
     _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
    criterion = loss_fn
    plugin = HybridParallelPlugin(**test_config)
    booster = Booster(plugin=plugin)

    def _criterion(outputs, inputs):
        outputs = output_transform_fn(outputs)
        loss = criterion(outputs)
        return loss

    model = model_fn().cuda()
    optimizer = Adam(model.parameters(), lr=1e-3)
    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

    new_model = model_fn().cuda()
    new_optimizer = Adam(new_model.parameters(), lr=1e-3)
    new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)

    data = data_gen_fn()
    model.train()
    if booster.plugin.stage_manager is not None:
        for k, v in data.items():
            if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__:
                new_shape = [1] * v.dim()
                new_shape[0] = 4
                data[k] = v.to('cuda').repeat(*new_shape)
        data_iter = iter([data])
        output = booster.execute_pipeline(data_iter,
                                          model,
                                          _criterion,
                                          optimizer,
                                          return_loss=True,
                                          return_outputs=False)
    else:
        data = {k: v.cuda() for k, v in data.items()}
        output = model(**data)
        loss = criterion(output)
        optimizer.backward(loss)

    optimizer.step()

    with shared_tempdir() as tempdir:
        model_ckpt_path = f"{tempdir}/model"
        # optimizer_ckpt_path = f"{tempdir}/optimizer"
        booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
        # booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
        dist.barrier()

        booster.load_model(new_model, model_ckpt_path)
        check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False)

    clear_layout_converter()


def run_dist(rank, world_size, port):
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    exam_state_dict()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [4])
@rerun_if_address_is_in_use()
def test_hybrid_ckpIO(world_size):
    spawn(run_dist, world_size)
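The test spawns four processes per parameterized configuration via spawn. Assuming the file sits alongside the other checkpoint IO tests (its exact filename is elided in this view), a typical invocation would be:

pytest -m dist tests/test_checkpoint_io/<this_test_file>.py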