mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] support sharded checkpoint IO for models of HybridParallelPlugin (#4506)
* add APIs
* implement save_sharded_model
* add test for hybrid checkpointio
* implement naive loading for sharded model
* implement efficient sharded model loading
* open a new file for hybrid checkpoint_io
* small fix
* fix circular importing
* fix docstring
* arrange arguments and apis
* small fix
parent de8a65babc
commit 44eab2b27f
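For orientation, here is a minimal usage sketch of the sharded checkpoint path this commit adds, mirroring the new test at the bottom of the diff. The GPT-2 model, the 4-rank layout, and the paths are illustrative assumptions; the plugin sizes must match the number of processes the script is launched with.

# Sketch only: run under a distributed launcher on 4 GPUs, e.g.
#   torchrun --nproc_per_node 4 save_load_sharded.py
import colossalai
from torch.optim import Adam
from transformers import GPT2Config, GPT2LMHeadModel   # illustrative model choice

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch(config={})

# 2-way tensor parallelism x 2-way pipeline parallelism on 4 ranks (assumed layout).
plugin = HybridParallelPlugin(tp_size=2, pp_size=2, num_microbatches=4, precision='fp32')
booster = Booster(plugin=plugin)

model = GPT2LMHeadModel(GPT2Config()).cuda()
optimizer = Adam(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)

# shard=True routes the call to the new HypridParallelCheckpointIO.save_sharded_model:
# each pipeline stage writes its own shard files plus one merged index file.
booster.save_model(model, './ckpt_gpt2', shard=True, size_per_shard=1024)

# Loading reads, per rank, only the shard files that hold its parameters and buffers.
new_model = GPT2LMHeadModel(GPT2Config()).cuda()
new_optimizer = Adam(new_model.parameters(), lr=1e-3)
new_model, new_optimizer, *_ = booster.boost(new_model, new_optimizer)
booster.load_model(new_model, './ckpt_gpt2')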
@@ -16,7 +16,7 @@ from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler

 from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer
-from colossalai.checkpoint_io import CheckpointIO
+from colossalai.checkpoint_io import CheckpointIO, HypridParallelCheckpointIO
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule

@@ -292,6 +292,7 @@ class HybridParallelPlugin(PipelinePluginBase):
             self.schedule = OneForwardOneBackwardSchedule(num_microbatches, self.stage_manager)
         self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
         self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
+        self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
         self.shard_config = ShardConfig(tensor_parallel_process_group=self.tp_group,
                                         pipeline_stage_manager=self.stage_manager,
                                         enable_tensor_parallelism=self.tp_size > 1,

@@ -460,7 +461,7 @@ class HybridParallelPlugin(PipelinePluginBase):
                           **_kwargs)

     def get_checkpoint_io(self) -> CheckpointIO:
-        return None
+        return HypridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group)

     def no_sync(self, model: Module) -> Iterator[None]:
         raise NotImplementedError
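get_checkpoint_io no longer returns None: the plugin now hands its dp/pp/tp process groups to the new IO class, which the Booster then uses for save_model/load_model. A small sketch of that wiring, assuming a distributed context has already been launched (sizes are illustrative and must match the world size):

import torch.distributed as dist

from colossalai.booster.plugin import HybridParallelPlugin

# Assumes colossalai.launch*/torch.distributed has already initialized 4 ranks.
plugin = HybridParallelPlugin(tp_size=2, pp_size=2, num_microbatches=4)
ckpt_io = plugin.get_checkpoint_io()   # a HypridParallelCheckpointIO wired with the plugin's groups
assert ckpt_io.dp_size * ckpt_io.pp_size * ckpt_io.tp_size == dist.get_world_size()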
@@ -1,5 +1,6 @@
 from .checkpoint_io_base import CheckpointIO
 from .general_checkpoint_io import GeneralCheckpointIO
+from .hybrid_parallel_checkpoint_io import HypridParallelCheckpointIO
 from .index_file import CheckpointIndexFile

-__all__ = ['CheckpointIO', 'CheckpointIndexFile', 'GeneralCheckpointIO']
+__all__ = ['CheckpointIO', 'CheckpointIndexFile', 'GeneralCheckpointIO', 'HybridParallelCheckpointIO']
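With the package __init__ updated, the new IO class can be imported from the package root (note that the class identifier keeps the "Hyprid" spelling used throughout this commit):

from colossalai.checkpoint_io import HypridParallelCheckpointIO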
@@ -0,0 +1,316 @@
import copy
import gc
import logging
import os
from pathlib import Path
from shutil import rmtree
from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler

from colossalai.cluster import ProcessGroupMesh
from colossalai.tensor.d_tensor import (
    is_customized_distributed_tensor,
    is_distributed_tensor,
    to_global,
    to_global_for_customized_distributed_tensor,
)

from .general_checkpoint_io import GeneralCheckpointIO
from .index_file import CheckpointIndexFile
from .utils import (
    StateDictSharder,
    calculate_tensor_size,
    gather_distributed_param,
    get_model_base_filenames,
    get_optimizer_base_filenames,
    get_shard_filename,
    is_safetensors_available,
    load_shard_state_dict,
    load_state_dict_into_model,
    save_param_groups,
    save_state_dict,
    save_state_dict_shards,
)

try:
    from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
except ImportError:
    _EXTRA_STATE_KEY_SUFFIX = '_extra_state'


class HypridParallelCheckpointIO(GeneralCheckpointIO):
    """
    CheckpointIO for Hybrid Parallel Training.

    Args:
        dp_group (ProcessGroup): Process group along data parallel dimension.
        pp_group (ProcessGroup): Process group along pipeline parallel dimension.
        tp_group (ProcessGroup): Process group along tensor parallel dimension.
    """

    def __init__(self, dp_group: ProcessGroup, pp_group: ProcessGroup, tp_group: ProcessGroup) -> None:
        super().__init__()
        self.dp_group = dp_group
        self.pp_group = pp_group
        self.tp_group = tp_group
        self.dp_rank = dist.get_rank(self.dp_group)
        self.tp_rank = dist.get_rank(self.tp_group)
        self.pp_rank = dist.get_rank(self.pp_group)
        self.dp_size = dist.get_world_size(dp_group)
        self.pp_size = dist.get_world_size(pp_group)
        self.tp_size = dist.get_world_size(tp_group)

    @staticmethod
    def _model_sharder(model: nn.Module,
                       prefix: str = '',
                       keep_vars: bool = False,
                       size_per_shard: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
        # An internal method that breaks the state_dict of the model into shards of limited size.

        state_dict_sharder = StateDictSharder(size_per_shard)

        # Save parameters.
        for name, param in model.named_parameters():
            if param is None:
                continue
            # Gather tensor pieces when using tensor parallel.
            param_ = gather_distributed_param(param, keep_vars=False)
            block, block_size = state_dict_sharder.append(prefix + name, param_)
            if block is not None:
                yield block, block_size

        # Save buffers.
        for name, buf in model.named_buffers():
            if buf is not None and name not in model._non_persistent_buffers_set:
                buffer = buf if keep_vars else buf.detach()
                block, block_size = state_dict_sharder.append(prefix + name, buffer)
                if block is not None:
                    yield block, block_size

        # Save extra states.
        extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
        if getattr(model.__class__, "get_extra_state",
                   torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
            extra_state = model.get_extra_state()
            block, block_size = state_dict_sharder.append(extra_state_key, extra_state)
            if block is not None:
                yield block, block_size

        # Return the last block in sharder.
        yield state_dict_sharder.current_block, state_dict_sharder.current_block_size

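_model_sharder streams (block, block_size) pairs, where each block is an OrderedDict of already-gathered tensors and block sizes are roughly bounded by size_per_shard megabytes (a single tensor larger than the limit still gets a block of its own); a final, possibly undersized block is always yielded at the end. A toy single-process illustration, assuming a plain non-distributed model so that gather_distributed_param is effectively a no-op:

import torch.nn as nn

from colossalai.checkpoint_io import HypridParallelCheckpointIO

# Two 2048x2048 fp32 linears: each weight is ~16 MB, far above the 10 MB limit used below.
model = nn.Sequential(nn.Linear(2048, 2048), nn.Linear(2048, 2048))

for block, block_size in HypridParallelCheckpointIO._model_sharder(model, size_per_shard=10):
    print(f"{len(block)} tensor(s), {block_size:.2f} MB")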
    @staticmethod
    def _optimizer_sharder(optimizer: Optimizer, size_per_shard: int = 1024):
        # An internal method that breaks the state_dict of the optimizer into shards of limited size.
        # TODO (Baizhou): Implement sharding feature of optimizer.
        pass

    def save_sharded_model(self,
                           model: nn.Module,
                           checkpoint: str,
                           gather_dtensor: bool = True,
                           prefix: Optional[str] = None,
                           size_per_shard: int = 1024,
                           use_safetensors: bool = False) -> None:
        """
        Save sharded model checkpoint under the given checkpointing path.
        The following files will be created under the path:
        - An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names.
        - Multiple files that store state tensors of models.
          If pipeline parallelism is used, the filenames are in the form of "pytorch_model.<prefix>-stage-000XX-shard-000XX.bin".
          If pipeline parallelism is not used, "pytorch_model.<prefix>-000XX.bin".

        Args:
            model (nn.Module): Model on local device to be saved.
            checkpoint (str): Checkpointing path which should be a directory path.
            gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True.
            prefix (str, optional): Prefix of file to save. Defaults to None.
            size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
            use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
        """

        if os.path.isfile(checkpoint):
            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
            return

        Path(checkpoint).mkdir(parents=True, exist_ok=True)

        # Devices along the same dp_group share the same copies of model.
        # So only let the device with dp_rank == 0 save the model.
        if self.dp_rank != 0:
            return

        # Then collect the sharded parameters & buffers along tp_group.
        # Only devices with tp_rank == 0 are responsible for model saving.
        state_dict_shard = HypridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard)
        weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
        index_file = CheckpointIndexFile(checkpoint)
        control_saving = (self.tp_rank == 0)

        if self.pp_size == 1:
            # When pipeline is not used, save the model shards as in general checkpointIO
            total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
                                                checkpoint=checkpoint,
                                                index_file=index_file,
                                                base_filename=weights_name,
                                                is_master=control_saving,
                                                use_safetensors=use_safetensors)
            if control_saving:
                index_file.append_meta_data("total_size", total_size)
                index_file.write_index_file(save_index_file)
                logging.info(f"The model is split into checkpoint shards. "
                             f"You can find where each parameter has been saved in the "
                             f"index located at {save_index_file}.")

        else:
            # When pipeline is used, each stage produces its own shard files and index files.
            # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
            # After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder.

            final_index_file_path = copy.deepcopy(save_index_file)
            tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files")
            Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)

            # Manage filenames of sharded weights and index file for each pipeline stage.
            weights_name = weights_name.replace(".bin", f"-stage-{self.pp_rank:05d}-shard.bin")
            weights_name = weights_name.replace(".safetensors", f"-stage-{self.pp_rank:05d}-shard.safetensors")
            save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank:05d}.json")
            save_index_file = os.path.join("tmp_index_files", save_index_file)

            total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
                                                checkpoint=checkpoint,
                                                index_file=index_file,
                                                base_filename=weights_name,
                                                is_master=control_saving,
                                                use_safetensors=use_safetensors)
            if control_saving:
                assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0."
                index_file.append_meta_data("total_size", total_size)
                index_file.write_index_file(save_index_file)
            else:
                return

            dist.barrier(self.pp_group)

            # The global master rank integrates the index files and cleans the folder.
            if self.pp_rank == 0:
                final_index_file = CheckpointIndexFile(checkpoint)
                final_index_file.append_meta_data("total_size", 0)

                for filename in os.listdir(tmp_index_file_folder):
                    stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename))
                    final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"]
                    for weight, weight_filename in stage_index_file.weight_map.items():
                        final_index_file.append_weight_map(weight, weight_filename)

                final_index_file.write_index_file(final_index_file_path)
                rmtree(tmp_index_file_folder)
                logging.info(f"The model is split into checkpoint shards. "
                             f"You can find where each parameter has been saved in the "
                             f"index located at {final_index_file_path}.")

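With pipeline parallelism, the saved directory ends up holding stage-specific shard files plus a single merged index; the per-stage index files written under tmp_index_files/ are removed once rank 0 of the pipeline group has merged them. An illustrative way to inspect the result, assuming the ./ckpt_gpt2 directory from the earlier sketch and the usual Hugging Face-style metadata/weight_map index layout:

import json

# Expected layout (illustrative, pp_size=2 with one shard file per stage):
#   ckpt_gpt2/pytorch_model.bin.index.json
#   ckpt_gpt2/pytorch_model-stage-00000-shard-00001.bin
#   ckpt_gpt2/pytorch_model-stage-00001-shard-00001.bin
with open('./ckpt_gpt2/pytorch_model.bin.index.json') as f:
    index = json.load(f)

print(index['metadata']['total_size'])               # summed over all pipeline stages
for name, filename in list(index['weight_map'].items())[:5]:
    print(f'{name} -> {filename}')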
    def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False):
        """
        Load sharded model with the given path to index file of checkpoint folder.

        Args:
            model (nn.Module): The model to be loaded.
            checkpoint_index_file (Path): Path to the index file of the checkpointing folder.
            strict (bool, optional): For name matching during loading state_dict. Defaults to False.
                This argument should be manually set to False since params on the same device might be stored in different files.
        """

        # Check whether the checkpoint uses safetensors.
        use_safetensors = False
        if "safetensors" in checkpoint_index_file.name:
            use_safetensors = True

        if use_safetensors and not is_safetensors_available():
            raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")

        # Read checkpoint index file.
        ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
        ckpt_root_path = ckpt_index_file.root_path
        weight_map = ckpt_index_file.weight_map
        strict = False

        # Load params & buffers to model.
        # Keep a record of loaded files so that a file will not be repeatedly loaded.
        loaded_file = set()

        def _load(name: str):
            if name not in weight_map:
                raise ValueError(f"{name} is not stored in checkpoint, please check your checkpointing configuration!")
            filename = weight_map[name]

            # If this param/buffer has been loaded before, directly return.
            if filename in loaded_file:
                return

            file_path = os.path.join(ckpt_root_path, filename)
            state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
            missing_keys = []

            load_state_dict_into_model(model,
                                       state_dict,
                                       missing_keys=missing_keys,
                                       strict=strict,
                                       load_sub_module=True)
            del state_dict
            loaded_file.add(filename)

        # Load parameters.
        for name, _ in model.named_parameters():
            _load(name)

        # Load buffers.
        for name, buf in model.named_buffers():
            if buf is not None and name not in model._non_persistent_buffers_set:
                _load(name)

        # Load extra states.
        extra_state_key = _EXTRA_STATE_KEY_SUFFIX
        if getattr(model.__class__, "get_extra_state",
                   torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
            _load(extra_state_key)

    def save_sharded_optimizer(self,
                               optimizer: Optimizer,
                               checkpoint: str,
                               gather_dtensor: bool = True,
                               prefix: Optional[str] = None,
                               size_per_shard: int = 1024):
        pass

    def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
        pass

    def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
        # TODO(Baizhou): support this feature after implementing complete state_dict collection
        raise NotImplementedError

    def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
        # TODO(Baizhou): support this feature after implementing complete state_dict collection
        raise NotImplementedError

    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
        # TODO(Baizhou): support this feature after implementing complete state_dict collection
        raise NotImplementedError

    def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
        # TODO(Baizhou): support this feature after implementing complete state_dict collection
        raise NotImplementedError

    def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
        """
        Save lr scheduler to checkpoint but only on master process.
        """
        if self.coordinator.is_master():
            super().save_lr_scheduler(lr_scheduler, checkpoint)
@@ -13,7 +13,12 @@ from torch.optim import Optimizer

 from colossalai.interface import OptimizerWrapper
 from colossalai.nn.optimizer import ColossalaiOptimizer
-from colossalai.tensor.d_tensor import is_distributed_tensor
+from colossalai.tensor.d_tensor import (
+    is_customized_distributed_tensor,
+    is_distributed_tensor,
+    to_global,
+    to_global_for_customized_distributed_tensor,
+)

 SAFE_WEIGHTS_NAME = "model.safetensors"
 WEIGHTS_NAME = "pytorch_model.bin"
@@ -88,8 +93,28 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
         return False


+def gather_distributed_param(param: torch.Tensor, keep_vars: bool = False):
+    """
+    Gather the complete parameter for saving if the passed-in param is distributed.
+
+    Args:
+        param (torch.Tensor): A model parameter, which might be a d_tensor.
+        keep_vars (bool, optional): Whether to return the parameter in the calculation graph. Defaults to False.
+
+    Returns:
+        torch.Tensor: the complete parameter
+    """
+    param_ = param if keep_vars else param.detach()
+    if is_distributed_tensor(param_):
+        return to_global(param_)
+    elif is_customized_distributed_tensor(param_):
+        return to_global_for_customized_distributed_tensor(param_)
+    else:
+        return param_
+
+
 # ======================================
-# Helper functions for saving shard file
+# Helper classes and functions for saving shard file
 # ======================================
 def unwrap_optimizer(optimizer: OptimizerWrapper):
     '''
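gather_distributed_param is a thin dispatcher: a regular tensor comes back unchanged apart from being detached (unless keep_vars=True), while a sharded d_tensor is gathered back to its global shape first. A minimal check for the non-distributed case:

import torch

from colossalai.checkpoint_io.utils import gather_distributed_param

param = torch.nn.Parameter(torch.randn(4, 8))
full = gather_distributed_param(param, keep_vars=False)

# For a plain parameter this is effectively just a detach.
assert full.shape == (4, 8) and not full.requires_grad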
@@ -104,6 +129,31 @@ def unwrap_optimizer(optimizer: OptimizerWrapper):
     return unwrapped_optim


+class StateDictSharder:
+
+    def __init__(self, size_per_shard: int) -> None:
+        self.max_shard_size = size_per_shard
+        self.current_block = OrderedDict()
+        self.current_block_size = 0
+
+    def append(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]:
+        tensor_size = calculate_tensor_size(tensor)
+        ret_block = None
+        ret_block_size = 0
+
+        # before we return the current block and create a new block,
+        # we need to ensure that the current block is not empty
+        if self.current_block_size + tensor_size > self.max_shard_size and self.current_block_size > 0:
+            ret_block = self.current_block
+            ret_block_size = self.current_block_size
+            self.current_block = OrderedDict()
+            self.current_block_size = 0
+
+        self.current_block[name] = tensor
+        self.current_block_size += tensor_size
+        return ret_block, ret_block_size
+
+
 def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
                            checkpoint: str,
                            index_file: "CheckpointIndexFile",
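StateDictSharder only closes a block when appending the next tensor would push a non-empty block over the limit, so a tensor bigger than the limit still ends up in a block of its own, and the caller must remember to flush the trailing block. A toy illustration (calculate_tensor_size measures megabytes):

import torch.nn as nn

from colossalai.checkpoint_io.utils import StateDictSharder

model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024))   # ~4 MB per weight
sharder = StateDictSharder(size_per_shard=1)    # deliberately tiny limit to force splits

blocks = []
for name, param in model.named_parameters():
    block, block_size = sharder.append(name, param.detach())
    if block is not None:                       # a full block was just closed
        blocks.append((block, block_size))
blocks.append((sharder.current_block, sharder.current_block_size))   # trailing block

print([round(size, 3) for _, size in blocks])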
@@ -126,9 +176,10 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]],

     total_size = 0
     for idx, shard_pair in enumerate(sharded_state_dict):
-        if not is_master:
-            continue
         shard, current_size = shard_pair
+        if not is_master:
+            del shard
+            continue
         shard_file = get_shard_filename(base_filename, idx)
         total_size = total_size + current_size
         for key in shard.keys():

@@ -137,6 +188,7 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]],

         # Only save on master rank.
         save_state_dict(shard, checkpoint_file_path, use_safetensors=use_safetensors)
+        del shard

     return total_size
@@ -10,6 +10,7 @@ import torch.nn as nn
 from torch.distributed import ProcessGroup
 from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, Module

+from colossalai.checkpoint_io.utils import gather_distributed_param
 from colossalai.tensor.d_tensor import (
     distribute_tensor,
     distribute_tensor_with_customization,

@@ -56,13 +57,7 @@ class ParallelModule(nn.Module, ABC):
         """
         for name, param in self._parameters.items():
             if param is not None:
-                param_ = param if keep_vars else param.detach()
-                if is_distributed_tensor(param_):
-                    destination[prefix + name] = to_global(param_)
-                elif is_customized_distributed_tensor(param_):
-                    destination[prefix + name] = to_global_for_customized_distributed_tensor(param_)
-                else:
-                    destination[prefix + name] = param_
+                destination[prefix + name] = gather_distributed_param(param, keep_vars=keep_vars)

         for name, buf in self._buffers.items():
             if buf is not None and name not in self._non_persistent_buffers_set:
@@ -8,7 +8,7 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn

-from colossalai.checkpoint_io.utils import calculate_tensor_size
+from colossalai.checkpoint_io.utils import StateDictSharder
 from colossalai.lazy import LazyTensor
 from colossalai.logging import get_dist_logger
 from colossalai.nn.parallel.data_parallel import ColoDDP, _cast_float, free_storage

@@ -657,7 +657,7 @@ class ZeroDDP(ColoDDP):
         Yields:
             Iterator[OrderedDict]: A generator of state dict shard
         """
-        sharder = _StateDictSharder(max_shard_size)
+        sharder = StateDictSharder(max_shard_size)

         # get the mapping between copies and fp16 parameters
         fp16_to_fp32 = dict()

@@ -705,30 +705,6 @@ class ZeroDDP(ColoDDP):
         yield sharder.current_block, sharder.current_block_size


-class _StateDictSharder:
-
-    def __init__(self, max_shard_size: int) -> None:
-        self.max_shard_size = max_shard_size
-        self.current_block = OrderedDict()
-        self.current_block_size = 0
-
-    def append(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]:
-        tensor_size = calculate_tensor_size(tensor)
-        ret_block = None
-        ret_block_size = 0
-
-        # before we return the current block and create a new block,
-        # we need to ensure that the current block is not empty
-        if self.current_block_size + tensor_size > self.max_shard_size and self.current_block_size > 0:
-            ret_block = self.current_block
-            ret_block_size = self.current_block_size
-            self.current_block = OrderedDict()
-            self.current_block_size = 0
-        self.current_block[name] = tensor
-        self.current_block_size += tensor_size
-        return ret_block, ret_block_size
-
-
 class GeminiDDP(ZeroDDP):

     def __init__(self,
@@ -0,0 +1,116 @@
import pytest
import torch
import torch.distributed as dist
from torch.optim import Adam
from utils import shared_tempdir

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import (
    check_state_dict_equal,
    clear_cache_before_run,
    parameterize,
    rerun_if_address_is_in_use,
    spawn,
)
from tests.kit.model_zoo import model_zoo


@clear_cache_before_run()
@parameterize('shard', [True])
@parameterize('model_name', ['transformers_gpt'])
@parameterize('size_per_shard', [32])
@parameterize('test_config', [{
    'tp_size': 2,
    'pp_size': 2,
    'num_microbatches': 4,
    'precision': 'fp32',
}, {
    'tp_size': 1,
    'pp_size': 2,
    'num_microbatches': 4,
    'precision': 'fp32',
}, {
    'tp_size': 4,
    'pp_size': 1,
    'precision': 'fp32',
}, {
    'tp_size': 2,
    'pp_size': 1,
    'precision': 'fp32',
}, {
    'tp_size': 2,
    'pp_size': 1,
    'zero_stage': 2,
    'precision': 'fp16',
    'initial_scale': 1
}])
def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict):

    (model_fn, data_gen_fn, output_transform_fn, loss_fn,
     _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
    criterion = loss_fn
    plugin = HybridParallelPlugin(**test_config)
    booster = Booster(plugin=plugin)

    def _criterion(outputs, inputs):
        outputs = output_transform_fn(outputs)
        loss = criterion(outputs)
        return loss

    model = model_fn().cuda()
    optimizer = Adam(model.parameters(), lr=1e-3)
    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

    new_model = model_fn().cuda()
    new_optimizer = Adam(new_model.parameters(), lr=1e-3)
    new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)

    data = data_gen_fn()
    model.train()
    if booster.plugin.stage_manager is not None:
        for k, v in data.items():
            if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__:
                new_shape = [1] * v.dim()
                new_shape[0] = 4
                data[k] = v.to('cuda').repeat(*new_shape)
        data_iter = iter([data])
        output = booster.execute_pipeline(data_iter,
                                          model,
                                          _criterion,
                                          optimizer,
                                          return_loss=True,
                                          return_outputs=False)
    else:
        data = {k: v.cuda() for k, v in data.items()}
        output = model(**data)
        loss = criterion(output)
        optimizer.backward(loss)

    optimizer.step()

    with shared_tempdir() as tempdir:
        model_ckpt_path = f"{tempdir}/model"
        # optimizer_ckpt_path = f"{tempdir}/optimizer"
        booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
        # booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
        dist.barrier()
        booster.load_model(new_model, model_ckpt_path)
        check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False)

    clear_layout_converter()


def run_dist(rank, world_size, port):
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    exam_state_dict()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [4])
@rerun_if_address_is_in_use()
def test_hybrid_ckpIO(world_size):
    spawn(run_dist, world_size)