mirror of https://github.com/hpcaitech/ColossalAI
[checkpointio] support unsharded checkpointIO for hybrid parallel (#4774)
* support unsharded saving/loading for model * support optimizer unsharded saving * update doc * support unsharded loading for optimizer * small fixpull/4807/head^2
parent
a2db75546d
commit
64a08b2dc3
|
@ -9,7 +9,6 @@ import torch
|
|||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
||||
|
||||
from colossalai.cluster import DistCoordinator
|
||||
|
@ -24,10 +23,12 @@ from .utils import (
|
|||
get_optimizer_base_filenames,
|
||||
is_safetensors_available,
|
||||
load_shard_state_dict,
|
||||
load_state_dict,
|
||||
load_state_dict_into_model,
|
||||
load_states_into_optimizer,
|
||||
save_config_file,
|
||||
save_param_groups,
|
||||
save_state_dict,
|
||||
save_state_dict_shards,
|
||||
search_tp_partition_dim,
|
||||
sharded_optimizer_loading_epilogue,
|
||||
|
@ -119,13 +120,13 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
use_zero: bool,
|
||||
dp_group: ProcessGroup,
|
||||
tp_group: ProcessGroup,
|
||||
master_to_working_map: Optional[Dict[int, torch.Tensor]] = None,
|
||||
size_per_shard: int = 1024,
|
||||
):
|
||||
# An internel method that breaks state_dict of optimizer into shards within limited size.
|
||||
|
||||
state_dict_sharder = StateDictSharder(size_per_shard)
|
||||
param_info = optimizer.param_info
|
||||
master_to_working_map = optimizer.get_master_to_working_map()
|
||||
|
||||
for param, state in optimizer.optim.state.items():
|
||||
if param is None:
|
||||
|
@ -217,7 +218,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
index_file.append_meta_data("total_size", total_size)
|
||||
index_file.write_index_file(save_index_file)
|
||||
save_config_file(model, checkpoint)
|
||||
if self.verbose:
|
||||
if self.verbose and self.coordinator.is_master():
|
||||
logging.info(
|
||||
f"The model is split into checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
|
@ -273,7 +274,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
final_index_file.write_index_file(final_index_file_path)
|
||||
save_config_file(model, checkpoint)
|
||||
rmtree(tmp_index_file_folder)
|
||||
if self.verbose:
|
||||
if self.verbose and self.coordinator.is_master():
|
||||
logging.info(
|
||||
f"The model is split into checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
|
@ -353,7 +354,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
# Update master params if mixed-precision training is enabled.
|
||||
model_before_wrapping.update_master_params()
|
||||
|
||||
if self.verbose:
|
||||
if self.verbose and self.coordinator.is_master():
|
||||
logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
|
||||
|
||||
def save_sharded_optimizer(
|
||||
|
@ -399,7 +400,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
use_zero=self.use_zero,
|
||||
dp_group=self.dp_group,
|
||||
tp_group=self.tp_group,
|
||||
master_to_working_map=optimizer.get_master_to_working_map(),
|
||||
size_per_shard=size_per_shard,
|
||||
)
|
||||
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
|
||||
|
@ -424,7 +424,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
# Store index file.
|
||||
index_file.append_meta_data("total_size", total_size)
|
||||
index_file.write_index_file(save_index_file)
|
||||
if self.verbose:
|
||||
if self.verbose and self.coordinator.is_master():
|
||||
logging.info(
|
||||
f"The optimizer is going to be split to checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
|
@ -484,7 +484,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
final_index_file.write_index_file(final_index_file_path)
|
||||
rmtree(tmp_index_file_folder)
|
||||
|
||||
if self.verbose:
|
||||
if self.verbose and self.coordinator.is_master():
|
||||
logging.info(
|
||||
f"The model is split into checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
|
@ -579,24 +579,196 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
optimizer.optim.state[param] = sharded_state
|
||||
|
||||
sharded_optimizer_loading_epilogue(optimizer.optim)
|
||||
if self.verbose:
|
||||
if self.verbose and self.coordinator.is_master():
|
||||
logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
|
||||
|
||||
def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
|
||||
# TODO(Baizhou): support this feature after implementing complete state_dict collection
|
||||
raise NotImplementedError
|
||||
def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
|
||||
"""
|
||||
Save model state dict to a single file with given checkpointing path.
|
||||
|
||||
def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
|
||||
# TODO(Baizhou): support this feature after implementing complete state_dict collection
|
||||
raise NotImplementedError
|
||||
Args:
|
||||
model (nn.Module): Model on local device to be saved.
|
||||
checkpoint (str): Checkpointing path which should be a file path. Can be absolute or relative path.
|
||||
gather_dtensor (bool, optional): Whether to gather dtensor, currently not used. Defaults to True.
|
||||
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
|
||||
"""
|
||||
if self.coordinator.is_master():
|
||||
logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
|
||||
|
||||
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
|
||||
# TODO(Baizhou): support this feature after implementing complete state_dict collection
|
||||
raise NotImplementedError
|
||||
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
|
||||
model = model.unwrap()
|
||||
|
||||
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
|
||||
# TODO(Baizhou): support this feature after implementing complete state_dict collection
|
||||
raise NotImplementedError
|
||||
if self.dp_rank != 0:
|
||||
return
|
||||
|
||||
# The logic of collecting parameter shards along tp degree
|
||||
# has been implemented by _save_to_state_dict method of ParallelModule in Shardformer.
|
||||
state_dict = model.state_dict()
|
||||
|
||||
if self.pp_size == 1:
|
||||
# When pipeline is not used, let master rank directly save the collected state_dict.
|
||||
if self.tp_rank == 0:
|
||||
save_state_dict(state_dict, checkpoint, use_safetensors)
|
||||
else:
|
||||
# When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict.
|
||||
state_dict_list = [None for _ in range(self.pp_size)]
|
||||
dist.barrier(self.pp_group)
|
||||
dist.all_gather_object(state_dict_list, state_dict, self.pp_group)
|
||||
|
||||
# Only the master rank do the saving.
|
||||
if self.coordinator.is_master():
|
||||
complete_state_dict = dict()
|
||||
for _state_dict in state_dict_list:
|
||||
complete_state_dict.update(_state_dict)
|
||||
save_state_dict(complete_state_dict, checkpoint, use_safetensors)
|
||||
|
||||
def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = False):
|
||||
"""
|
||||
Load model from a single file with the given path of checkpoint.
|
||||
|
||||
Args:
|
||||
model (nn.Module): The model to be loaded.
|
||||
checkpoint_index_file (str): Path to the checkpoint file.
|
||||
strict (bool, optional): For name matching during loading state_dict. Defaults to False.
|
||||
This argument should be manually set to False since not all params in checkpoint are needed for each device when pipeline is enabled.
|
||||
"""
|
||||
if self.coordinator.is_master():
|
||||
logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
|
||||
|
||||
assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
|
||||
strict = False
|
||||
model_before_wrapping = model
|
||||
model = model.unwrap()
|
||||
|
||||
# Load from checkpoint. Since the logic of breaking parameter shards along tp degree
|
||||
# has been implemented by _load_from_state_dict method of ParallelModule in Shardformer,
|
||||
# model.load_state_dict can be directly called.
|
||||
state_dict = load_state_dict(checkpoint)
|
||||
model.load_state_dict(state_dict, strict=strict)
|
||||
|
||||
# Update master params if mixed-precision training is enabled.
|
||||
model_before_wrapping.update_master_params()
|
||||
|
||||
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
|
||||
"""
|
||||
Save optimizer state dict to a file with given path.
|
||||
|
||||
Args:
|
||||
optimizer (OptimizerWrapper): Optimizer to save sharded state_dict.
|
||||
checkpoint (str): Path to save optimizer state_dict.
|
||||
gather_dtensor (bool): Whether to gather_dtensor, not used.
|
||||
"""
|
||||
if self.coordinator.is_master():
|
||||
logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
|
||||
|
||||
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
|
||||
|
||||
# optimizer states of parameters kept by local device('s pipeline stage)
|
||||
local_states = dict()
|
||||
|
||||
for param, state in optimizer.optim.state.items():
|
||||
if param is None:
|
||||
continue
|
||||
|
||||
# working param is needed for obtaining correct param_id
|
||||
master_to_working_map = optimizer.get_master_to_working_map()
|
||||
if master_to_working_map is not None:
|
||||
working_param = master_to_working_map[id(param)]
|
||||
else:
|
||||
working_param = param
|
||||
|
||||
# gather complete state from tp shards & dp shards
|
||||
param_id = optimizer.param_info["param2id"][id(working_param)]
|
||||
original_shape = optimizer.param_info["param2shape"][id(working_param)]
|
||||
local_states[param_id] = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state(
|
||||
state,
|
||||
working_param,
|
||||
original_shape=original_shape,
|
||||
dp_group=self.dp_group,
|
||||
tp_group=self.tp_group,
|
||||
use_zero=self.use_zero,
|
||||
inplace=False,
|
||||
device=torch.device("cuda"),
|
||||
)
|
||||
|
||||
if self.pp_size == 1:
|
||||
# When pipeline is not used, let master rank directly save the collected state_dict.
|
||||
state_dict = {"param_groups": optimizer.param_info["param_groups"], "state": local_states}
|
||||
if self.coordinator.is_master():
|
||||
save_state_dict(state_dict, checkpoint, use_safetensors=False)
|
||||
else:
|
||||
# When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict.
|
||||
states_list = [None for _ in range(self.pp_size)]
|
||||
dist.barrier(self.pp_group)
|
||||
dist.all_gather_object(states_list, local_states, self.pp_group)
|
||||
|
||||
# Only the master rank do the saving.
|
||||
if self.coordinator.is_master():
|
||||
state_dict = {"param_groups": optimizer.param_info["param_groups"], "state": dict()}
|
||||
for _states in states_list:
|
||||
state_dict["state"].update(_states)
|
||||
save_state_dict(state_dict, checkpoint, use_safetensors=False)
|
||||
|
||||
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
|
||||
"""
|
||||
Load optimizer from a file with given path.
|
||||
|
||||
Args:
|
||||
optimizer (OptimizerWrapper): The optimizer to be loaded.
|
||||
checkpoint_index_file (str): Path to the checkpoint file.
|
||||
"""
|
||||
|
||||
def _get_param_id_from_optimizer_param(
|
||||
param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
|
||||
):
|
||||
if master_to_working_map is not None:
|
||||
working_param = master_to_working_map[id(param)]
|
||||
else:
|
||||
working_param = param
|
||||
return optimizer.param_info["param2id"][id(working_param)]
|
||||
|
||||
if self.coordinator.is_master():
|
||||
logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
|
||||
|
||||
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
|
||||
|
||||
# Complete optimizer state_dict loaded from checkpoint, need to be processed later.
|
||||
state_dict = load_state_dict(checkpoint)
|
||||
|
||||
# Load param_groups.
|
||||
updated_groups = []
|
||||
saved_groups = state_dict["param_groups"]
|
||||
for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
|
||||
new_pg = copy.deepcopy(saved_pg)
|
||||
new_pg["params"] = old_pg["params"] # Only keep the parameters kept by current pipeline stage.
|
||||
updated_groups.append(new_pg)
|
||||
optimizer.optim.__dict__.update({"param_groups": updated_groups})
|
||||
|
||||
# Load saved states to optimizer. First discard those states not belonging to current pipeline stage.
|
||||
master_to_working_map = optimizer.get_master_to_working_map()
|
||||
id_map = {}
|
||||
for pg in optimizer.optim.param_groups:
|
||||
for param in pg["params"]:
|
||||
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
|
||||
id_map[param_id] = param
|
||||
load_states_into_optimizer(optimizer.optim, state_dict["state"], id_map, strict=True)
|
||||
|
||||
# Then shard the loaded optimizer states if using tp/zero.
|
||||
for param, state in optimizer.optim.state.items():
|
||||
if param is None:
|
||||
continue
|
||||
device = param.device
|
||||
if master_to_working_map is not None:
|
||||
working_param = master_to_working_map[id(param)]
|
||||
else:
|
||||
working_param = param
|
||||
original_shape = optimizer.param_info["param2shape"][id(working_param)]
|
||||
sharded_state = self.shard_from_complete_optimizer_state(
|
||||
state, current_shape=working_param.shape, original_shape=original_shape, device=device, inplace=True
|
||||
)
|
||||
optimizer.optim.state[param] = sharded_state
|
||||
|
||||
sharded_optimizer_loading_epilogue(optimizer.optim)
|
||||
|
||||
def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
|
||||
"""
|
||||
|
@ -614,6 +786,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
tp_group: ProcessGroup,
|
||||
use_zero: bool,
|
||||
inplace: bool,
|
||||
device: torch.device = torch.device("cpu"),
|
||||
) -> OrderedDict:
|
||||
"""
|
||||
With given parameter and its optimizer states, gather the complete optimizer state for saving.
|
||||
|
@ -626,6 +799,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
tp_group (ProcessGroup): The process group of tensor parallel.
|
||||
use_zero (bool): Whether Zero is used.
|
||||
inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.
|
||||
device (torch.device): The destination device of loaded optimizer states. Defaults to torch.device('cpu').
|
||||
|
||||
Returns:
|
||||
OrderedDict: The complete optimizer state of given parameter.
|
||||
|
@ -651,7 +825,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
dist.all_gather(gather_tensor, v, group=tp_group)
|
||||
v = torch.cat(gather_tensor, dim=partition_dim)
|
||||
|
||||
state_[k] = v.detach().clone().cpu()
|
||||
state_[k] = v.detach().clone().to(device)
|
||||
|
||||
return state_
|
||||
|
||||
|
|
|
@ -74,8 +74,6 @@ This plugin implements the combination of various parallel training strategies a
|
|||
|
||||
> ⚠ When using this plugin, only the subset of Huggingface transformers supported by Shardformer are compatible with tensor parallel, pipeline parallel and optimization tools. Mainstream transformers such as Llama 1, Llama 2, OPT, Bloom, Bert and GPT2 etc. are all supported by Shardformer.
|
||||
|
||||
> ⚠ This plugin only supports sharded checkpointing methods for model/optimizer at present. Unsharded checkpointing methods will be supported in future release.
|
||||
|
||||
{{ autodoc:colossalai.booster.plugin.HybridParallelPlugin }}
|
||||
|
||||
### Torch DDP Plugin
|
||||
|
|
|
@ -71,8 +71,6 @@ Zero-2 不支持局部梯度累积。如果您坚持使用,虽然可以积累
|
|||
|
||||
> ⚠ 在使用该插件的时候, 只有支持Shardformer的部分Huggingface transformers模型才能够使用张量并行、流水线并行以及优化工具。Llama 1、Llama 2、OPT、Bloom、Bert以及GPT2等主流transformers模型均已支持Shardformer。
|
||||
|
||||
> ⚠ 该插件当前只对模型和优化器支持分片的checkpoint方法。不分片的checkpoint方法会在未来的版本中被支持。
|
||||
|
||||
{{ autodoc:colossalai.booster.plugin.HybridParallelPlugin }}
|
||||
|
||||
### Torch DDP 插件
|
||||
|
|
|
@ -20,9 +20,8 @@ from colossalai.testing import (
|
|||
from tests.kit.model_zoo import model_zoo
|
||||
|
||||
|
||||
# TODO (Baizhou): Add test cases for shard=False
|
||||
@clear_cache_before_run()
|
||||
@parameterize("shard", [True])
|
||||
@parameterize("shard", [True, False])
|
||||
@parameterize("model_name", ["transformers_gpt"])
|
||||
@parameterize("size_per_shard", [32])
|
||||
@parameterize(
|
||||
|
|
Loading…
Reference in New Issue