mirror of https://github.com/hpcaitech/ColossalAI

[shardformer] support sharded optimizer checkpointIO of HybridParallelPlugin (#4540)

* implement sharded optimizer saving
* add more param info
* finish implementation of sharded optimizer saving
* fix bugs in optimizer sharded saving
* add pp+zero test
* param group loading
* greedy loading of optimizer
* fix bug when loading
* implement optimizer sharded saving
* add optimizer test & arrange checkpointIO utils
* fix gemini sharding state_dict
* add verbose option
* add loading of master params
* fix typehint
* fix master/working mapping in fp16 amp

(pull/4566/head)
parent 2c787d7f47
commit c9625dbb63

@ -1,7 +1,7 @@
|
|||
import random
|
||||
from contextlib import nullcontext
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Iterator, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -110,6 +110,36 @@ class HybridParallelModule(ModelWrapper):
|
|||
return module
|
||||
|
||||
|
||||
def get_param_info(optim: Optimizer):
|
||||
# Get a backup of necessary information of parameters for future use, which includes:
|
||||
# 1. A complete param_group, with params in the form of param_id
|
||||
# 2. A mapping from param address (obtained using id(param)) to integer param_id
|
||||
# 3. A mapping from integer param_id to param address.
|
||||
# 4. A mapping from param_address (obtained using id(param)) to the original shape of parameter before sharding.
|
||||
# When Zero is used, the params here are fp16/bf16 model params rather than fp32 master params in optimizer.
|
||||
|
||||
if optim is None:
|
||||
return {}
|
||||
param_info = {'param_groups': [], 'param2id': {}, 'id2param': {}, 'param2shape': {}}
|
||||
start_index = 0
|
||||
for group in optim.param_groups:
|
||||
|
||||
packed_group = {k: v for k, v in group.items() if k != 'params'}
|
||||
packed_group['params'] = []
|
||||
|
||||
for param_id, param in enumerate(group['params'], start_index):
|
||||
original_shape = param.shape if isinstance(param, torch.Tensor) else None
|
||||
packed_group['params'].append(param_id)
|
||||
param_info['param2id'][id(param)] = param_id
|
||||
param_info['id2param'][param_id] = id(param)
|
||||
param_info['param2shape'][id(param)] = original_shape
|
||||
|
||||
param_info['param_groups'].append(packed_group)
|
||||
start_index += len(group['params'])
|
||||
|
||||
return param_info
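To make the returned structure concrete, here is a minimal, hypothetical sketch (not part of the diff) for a plain two-parameter SGD optimizer:

import torch
from torch.optim import SGD

w = torch.nn.Parameter(torch.randn(4, 2))
b = torch.nn.Parameter(torch.randn(4))
optim = SGD([w, b], lr=0.1)

info = get_param_info(optim)
assert info['param2id'][id(w)] == 0 and info['param2id'][id(b)] == 1
assert info['id2param'][1] == id(b)
assert info['param2shape'][id(w)] == torch.Size([4, 2])
assert info['param_groups'][0]['params'] == [0, 1]   # params are stored as integer ids
assert 'lr' in info['param_groups'][0]               # other group hyperparameters are kept as-is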
|
||||
|
||||
|
||||
def init_pipeline_optimizer(optim: Optimizer, model: Module):
|
||||
params = set(model.parameters())
|
||||
new_param_groups = []
|
||||
|
@ -121,7 +151,8 @@ def init_pipeline_optimizer(optim: Optimizer, model: Module):
|
|||
|
||||
class HybridParallelNaiveOptimizer(OptimizerWrapper):
|
||||
|
||||
def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool):
|
||||
def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool, param_info: OrderedDict):
|
||||
self.param_info = param_info
|
||||
if use_pipeline:
|
||||
init_pipeline_optimizer(optim, model)
|
||||
super().__init__(optim)
|
||||
|
@ -133,6 +164,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
|
|||
optim: Optimizer,
|
||||
model: Module,
|
||||
use_pipeline: bool,
|
||||
param_info: OrderedDict,
|
||||
precision: str = 'fp16',
|
||||
initial_scale: float = 2**16,
|
||||
min_scale: float = 1,
|
||||
|
@ -142,6 +174,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
|
|||
hysteresis: int = 2,
|
||||
max_scale: float = 2**32,
|
||||
max_norm: float = 0):
|
||||
self.param_info = param_info
|
||||
if use_pipeline:
|
||||
init_pipeline_optimizer(optim, model)
|
||||
super().__init__(optim, precision, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval,
|
||||
|
@ -155,6 +188,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
|
|||
optimizer: Optimizer,
|
||||
model: Module,
|
||||
use_pipeline: bool,
|
||||
param_info: OrderedDict,
|
||||
initial_scale: int = 2**16, # grad scaler config
|
||||
min_scale: int = 1,
|
||||
growth_factor: float = 2.,
|
||||
|
@ -172,6 +206,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
|
|||
dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm
|
||||
tp_process_group: Optional[ProcessGroup] = None, # if using tp
|
||||
forced_dtype: Optional[torch.dtype] = None):
|
||||
self.param_info = param_info
|
||||
if use_pipeline:
|
||||
init_pipeline_optimizer(optimizer, model)
|
||||
super().__init__(optimizer, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval,
|
||||
|
@ -356,6 +391,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
dataloader: Optional[DataLoader] = None,
|
||||
lr_scheduler: Optional[LRScheduler] = None,
|
||||
) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
|
||||
param_info = get_param_info(optimizer)
|
||||
if not isinstance(model, ModelWrapper):
|
||||
use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
|
||||
model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group, use_ddp,
|
||||
|
@ -366,25 +402,33 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
optimizer = HybridParallelAMPOptimizer(optimizer,
|
||||
model,
|
||||
use_pipeline=self.enable_pipeline_parallelism,
|
||||
param_info=param_info,
|
||||
precision=self.precision,
|
||||
max_norm=self.max_norm,
|
||||
**self.amp_config)
|
||||
self.checkpoint_io.link_master_and_working_param(optimizer.working_to_master_map,
|
||||
optimizer.master_to_working_map)
|
||||
else:
|
||||
optimizer = HybridParallelNaiveOptimizer(optimizer,
|
||||
model,
|
||||
use_pipeline=self.enable_pipeline_parallelism)
|
||||
use_pipeline=self.enable_pipeline_parallelism,
|
||||
param_info=param_info)
|
||||
else:
|
||||
assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1."
|
||||
assert self.precision != 'fp32', "Please set precision to 'fp16' or 'bf16' when using ZeRO."
|
||||
optimizer = HybridParallelZeroOptimizer(optimizer,
|
||||
model,
|
||||
use_pipeline=self.enable_pipeline_parallelism,
|
||||
param_info=param_info,
|
||||
dp_process_group=self.dp_group,
|
||||
tp_process_group=self.tp_group,
|
||||
verbose=True,
|
||||
clip_grad_norm=self.max_norm,
|
||||
**self.zero_config,
|
||||
**self.amp_config)
|
||||
self.checkpoint_io.link_master_and_working_param(optimizer._param_store.working_to_master_param,
|
||||
optimizer._param_store.master_to_working_param)
|
||||
|
||||
return model, optimizer, criterion, dataloader, lr_scheduler
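As a usage note, the linking above is what later lets the checkpoint IO translate between master and working params during save/load. A hedged end-to-end sketch through the Booster API follows (model, optimizer, criterion and dataloader are placeholders, and 4 GPUs are assumed so that dp_size > 1 for ZeRO):

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(tp_size=2, pp_size=1, zero_stage=1, precision='fp16')
booster = Booster(plugin=plugin)
model, optimizer, criterion, dataloader, _ = booster.boost(model, optimizer, criterion, dataloader)
# ... training steps ...
booster.save_optimizer(optimizer, 'ckpt/optim', shard=True, size_per_shard=1024)
booster.load_optimizer(optimizer, 'ckpt/optim/pytorch_optim.bin.index.json')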
|
||||
|
||||
def execute_pipeline(self,
|
||||
|
@ -461,7 +505,8 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
**_kwargs)
|
||||
|
||||
def get_checkpoint_io(self) -> CheckpointIO:
|
||||
return HypridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group)
|
||||
self.checkpoint_io = HypridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
|
||||
return self.checkpoint_io
|
||||
|
||||
def no_sync(self, model: Module) -> Iterator[None]:
|
||||
raise NotImplementedError
|
||||
@ -4,7 +4,7 @@ import logging
|
|||
import os
|
||||
from pathlib import Path
|
||||
from shutil import rmtree
|
||||
from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union
|
||||
from typing import Dict, Iterator, Optional, OrderedDict, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
@ -13,29 +13,23 @@ from torch.distributed import ProcessGroup
|
|||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
||||
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.tensor.d_tensor import (
|
||||
is_customized_distributed_tensor,
|
||||
is_distributed_tensor,
|
||||
to_global,
|
||||
to_global_for_customized_distributed_tensor,
|
||||
)
|
||||
from colossalai.interface import OptimizerWrapper
|
||||
|
||||
from .general_checkpoint_io import GeneralCheckpointIO
|
||||
from .index_file import CheckpointIndexFile
|
||||
from .utils import (
|
||||
StateDictSharder,
|
||||
calculate_tensor_size,
|
||||
gather_distributed_param,
|
||||
get_model_base_filenames,
|
||||
get_optimizer_base_filenames,
|
||||
get_shard_filename,
|
||||
is_safetensors_available,
|
||||
load_shard_state_dict,
|
||||
load_state_dict_into_model,
|
||||
load_states_into_optimizer,
|
||||
save_param_groups,
|
||||
save_state_dict,
|
||||
save_state_dict_shards,
|
||||
search_tp_partition_dim,
|
||||
sharded_optimizer_loading_epilogue,
|
||||
)
|
||||
|
||||
try:
|
||||
|
@ -52,9 +46,16 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
dp_group (ProcessGroup): Process group along data parallel dimension.
|
||||
pp_group (ProcessGroup): Process group along pipeline parallel dimension.
|
||||
tp_group (ProcessGroup): Process group along tensor parallel dimension.
|
||||
zero_stage (int): The zero stage of plugin. Should be in [0, 1, 2].
|
||||
verbose (bool, optional): Whether to print a logging message when saving/loading has been successfully executed. Defaults to True.
|
||||
"""
|
||||
|
||||
def __init__(self, dp_group: ProcessGroup, pp_group: ProcessGroup, tp_group: ProcessGroup) -> None:
|
||||
def __init__(self,
|
||||
dp_group: ProcessGroup,
|
||||
pp_group: ProcessGroup,
|
||||
tp_group: ProcessGroup,
|
||||
zero_stage: int,
|
||||
verbose: bool = True) -> None:
|
||||
super().__init__()
|
||||
self.dp_group = dp_group
|
||||
self.pp_group = pp_group
|
||||
|
@ -65,6 +66,10 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
self.dp_size = dist.get_world_size(dp_group)
|
||||
self.pp_size = dist.get_world_size(pp_group)
|
||||
self.tp_size = dist.get_world_size(tp_group)
|
||||
self.use_zero = (zero_stage > 0)
|
||||
self.verbose = verbose
|
||||
self.working_to_master_map = None
|
||||
self.master_to_working_map = None
|
||||
|
||||
@staticmethod
|
||||
def _model_sharder(model: nn.Module,
|
||||
|
@ -81,7 +86,7 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
continue
|
||||
# Gather tensor pieces when using tensor parallel.
|
||||
param_ = gather_distributed_param(param, keep_vars=False)
|
||||
block, block_size = state_dict_sharder.append(prefix + name, param_)
|
||||
block, block_size = state_dict_sharder.append_param(prefix + name, param_)
|
||||
if block is not None:
|
||||
yield block, block_size
|
||||
|
||||
|
@ -89,7 +94,7 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
for name, buf in model.named_buffers():
|
||||
if buf is not None and name not in model._non_persistent_buffers_set:
|
||||
buffer = buf if keep_vars else buf.detach()
|
||||
block, block_size = state_dict_sharder.append(prefix + name, buffer)
|
||||
block, block_size = state_dict_sharder.append_param(prefix + name, buffer)
|
||||
if block is not None:
|
||||
yield block, block_size
|
||||
|
||||
|
@ -98,7 +103,7 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
if getattr(model.__class__, "get_extra_state",
|
||||
torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
|
||||
extra_state = model.get_extra_state()
|
||||
block, block_size = state_dict_sharder.append(extra_state_key, extra_state)
|
||||
block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state)
|
||||
if block is not None:
|
||||
yield block, block_size
|
||||
|
||||
|
@ -106,10 +111,44 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
|
||||
|
||||
@staticmethod
|
||||
def _optimizer_sharder(optimizer: Optimizer, size_per_shard: int = 1024):
|
||||
def _optimizer_sharder(optimizer: OptimizerWrapper,
|
||||
use_zero: bool,
|
||||
dp_group: ProcessGroup,
|
||||
tp_group: ProcessGroup,
|
||||
master_to_working_map: Optional[Dict[int, torch.Tensor]] = None,
|
||||
size_per_shard: int = 1024):
|
||||
|
||||
# An internal method that breaks the state_dict of the optimizer into shards within limited size.
|
||||
# TODO (Baizhou): Implement sharding feature of optimizer.
|
||||
pass
|
||||
|
||||
state_dict_sharder = StateDictSharder(size_per_shard)
|
||||
param_info = optimizer.param_info
|
||||
|
||||
for param, state in optimizer.optim.state.items():
|
||||
|
||||
if param is None:
|
||||
continue
|
||||
|
||||
if master_to_working_map is not None:
|
||||
working_param = master_to_working_map[id(param)]
|
||||
else:
|
||||
working_param = param
|
||||
|
||||
param_id = param_info['param2id'][id(working_param)]
|
||||
original_shape = param_info['param2shape'][id(working_param)]
|
||||
state_ = HypridParallelCheckpointIO.gather_from_sharded_optimizer_state(state,
|
||||
working_param,
|
||||
original_shape=original_shape,
|
||||
dp_group=dp_group,
|
||||
tp_group=tp_group,
|
||||
use_zero=use_zero,
|
||||
inplace=False)
|
||||
|
||||
block, block_size = state_dict_sharder.append_optim_state(param_id, state_)
|
||||
if block is not None:
|
||||
yield block, block_size
|
||||
|
||||
# Return the last block in sharder.
|
||||
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
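A short consumption sketch for the generator above (illustrative only; in this PR the iterator is fed into save_state_dict_shards, and the optimizer/process-group objects are assumed to already exist):

shard_iter = HypridParallelCheckpointIO._optimizer_sharder(optimizer,
                                                           use_zero=True,
                                                           dp_group=dp_group,
                                                           tp_group=tp_group,
                                                           master_to_working_map=master_to_working_map,
                                                           size_per_shard=1024)
for block, block_size in shard_iter:
    # each block maps integer param_id -> fully gathered optimizer state for that parameter
    if block:                       # the trailing block may be empty
        print(f"shard holds {len(block)} params (size {block_size})")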
|
||||
|
||||
def save_sharded_model(self,
|
||||
model: nn.Module,
|
||||
|
@ -148,7 +187,7 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
return
|
||||
|
||||
# Then collect the sharded parameters & buffers along tp_group.
|
||||
# Only devices with tp_size == 0 are responsible for model saving.
|
||||
# Only devices with tp_rank == 0 are responsible for model saving.
|
||||
state_dict_shard = HypridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard)
|
||||
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
|
||||
index_file = CheckpointIndexFile(checkpoint)
|
||||
|
@ -165,9 +204,10 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
if control_saving:
|
||||
index_file.append_meta_data("total_size", total_size)
|
||||
index_file.write_index_file(save_index_file)
|
||||
logging.info(f"The model is split into checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {save_index_file}.")
|
||||
if self.verbose:
|
||||
logging.info(f"The model is split into checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {save_index_file}.")
|
||||
|
||||
else:
|
||||
# When pipeline is used, each stage produces its own shard files and index files.
|
||||
|
@ -212,9 +252,10 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
|
||||
final_index_file.write_index_file(final_index_file_path)
|
||||
rmtree(tmp_index_file_folder)
|
||||
logging.info(f"The model is split into checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {final_index_file_path}.")
|
||||
if self.verbose:
|
||||
logging.info(f"The model is split into checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {final_index_file_path}.")
|
||||
|
||||
def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False):
|
||||
"""
|
||||
|
@ -222,7 +263,7 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
|
||||
Args:
|
||||
model (nn.Module): The model to be loaded.
|
||||
index_file_path (str): Path to the index file of checkpointing folder.
|
||||
checkpoint_index_file (str): Path to the index file of checkpointing folder.
|
||||
strict (bool, optional): For name matching during loading state_dict. Defaults to False.
|
||||
This argument should be manually set to False since params on same device might be stored in different files.
|
||||
"""
|
||||
|
@ -263,7 +304,6 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
missing_keys=missing_keys,
|
||||
strict=strict,
|
||||
load_sub_module=True)
|
||||
del state_dict
|
||||
loaded_file.add(filename)
|
||||
|
||||
# Load parameters.
|
||||
|
@ -271,8 +311,11 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
_load(name)
|
||||
|
||||
# Load buffers.
|
||||
non_persistent_buffers = set()
|
||||
for n, m in model.named_modules():
|
||||
non_persistent_buffers |= set('.'.join((n, b)) for b in m._non_persistent_buffers_set)
|
||||
for name, buf in model.named_buffers():
|
||||
if buf is not None and name not in model._non_persistent_buffers_set:
|
||||
if buf is not None and name not in non_persistent_buffers:
|
||||
_load(name)
|
||||
|
||||
# Load extra states.
|
||||
|
@ -281,16 +324,236 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
|
||||
_load(extra_state_key)
|
||||
|
||||
# Update master params if mixed-precision training is enabled.
|
||||
with torch.no_grad():
|
||||
if self.working_to_master_map is not None:
|
||||
for param in model.parameters():
|
||||
if (param is None) or (id(param) not in self.working_to_master_map):
|
||||
continue
|
||||
master_param = self.working_to_master_map[id(param)]
|
||||
if self.use_zero:
|
||||
# master_param is sharded under Zero setting
|
||||
padding_size = (self.dp_size - param.numel() % self.dp_size) % self.dp_size
|
||||
if padding_size > 0:
|
||||
padded_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
|
||||
else:
|
||||
padded_param = param.data.view(-1)
|
||||
sharded_param = padded_param.split(padded_param.numel() // self.dp_size)[self.dp_rank]
|
||||
master_param.data.copy_(sharded_param.data)
|
||||
else:
|
||||
master_param.data.copy_(param.data)
|
||||
|
||||
if self.verbose:
|
||||
logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
|
||||
|
||||
def save_sharded_optimizer(self,
|
||||
optimizer: Optimizer,
|
||||
optimizer: OptimizerWrapper,
|
||||
checkpoint: str,
|
||||
gather_dtensor: bool = True,
|
||||
prefix: Optional[str] = None,
|
||||
size_per_shard: int = 1024):
|
||||
pass
|
||||
"""
|
||||
Save sharded optimizer checkpoint under the given checkpointing path.
|
||||
The following files will be created under the path:
|
||||
- An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names
|
||||
- A group file (pytorch_optim_group.bin) recording information of param_groups
|
||||
- Multiple files that store state tensors of optimizers.
|
||||
If pipeline parallelism is used, the filenames are in the form of "pytorch_optim.<prefix>-stage-000XX-shard-000XX.bin".
|
||||
If pipeline parallelism is not used, "pytorch_optim.<prefix>-000XX.bin"
|
||||
|
||||
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
|
||||
pass
|
||||
Args:
|
||||
optimizer (OptimizerWrapper): Optimizer to save sharded state_dict
|
||||
checkpoint (str): Path to save optimizer state_dict
|
||||
gather_dtensor (bool): Whether to gather_dtensor, not used
|
||||
prefix (str): Prefix of file to save
|
||||
size_per_shard (int): Max file size of each file shard that store state tensors
|
||||
"""
|
||||
if os.path.isfile(checkpoint):
|
||||
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
|
||||
return
|
||||
|
||||
Path(checkpoint).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Devices along the same dp_group share the same copies of states when zero is not used.
|
||||
# In this case only let the device with dp_rank == 0 save the optimizer states.
|
||||
if not self.use_zero and self.dp_rank != 0:
|
||||
return
|
||||
|
||||
# Then collect the sharded states along dp_group(if using zero)/tp_group.
|
||||
# Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving.
|
||||
state_dict_shard = HypridParallelCheckpointIO._optimizer_sharder(
|
||||
optimizer,
|
||||
use_zero=self.use_zero,
|
||||
dp_group=self.dp_group,
|
||||
tp_group=self.tp_group,
|
||||
master_to_working_map=self.master_to_working_map,
|
||||
size_per_shard=size_per_shard)
|
||||
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
|
||||
index_file = CheckpointIndexFile(checkpoint)
|
||||
control_saving = (self.dp_rank == 0 and self.tp_rank == 0)
|
||||
|
||||
if self.pp_size == 1:
|
||||
# When pipeline is not used, save the optimizer shards as in general checkpointIO
|
||||
total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=states_name,
|
||||
is_master=control_saving)
|
||||
|
||||
if control_saving:
|
||||
# Store param groups.
|
||||
index_file.append_meta_data("param_groups", param_group_file)
|
||||
group_file_path = os.path.join(checkpoint, param_group_file)
|
||||
save_param_groups(optimizer.param_info, group_file_path)
|
||||
# Store index file.
|
||||
index_file.append_meta_data("total_size", total_size)
|
||||
index_file.write_index_file(save_index_file)
|
||||
if self.verbose:
|
||||
logging.info(f"The optimizer is going to be split to checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {save_index_file}.")
|
||||
|
||||
else:
|
||||
# When pipeline is used, each stage produces its own shard files and index files.
|
||||
# Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
|
||||
# After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder.
|
||||
|
||||
final_index_file_path = copy.deepcopy(save_index_file)
|
||||
tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files")
|
||||
Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Manage filenames of sharded weights and index file for each pipeline stage.
|
||||
states_name = states_name.replace(".bin", f"-stage-{self.pp_rank:05d}-shard.bin")
|
||||
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank:05d}.json")
|
||||
save_index_file = os.path.join("tmp_index_files", save_index_file)
|
||||
|
||||
total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=states_name,
|
||||
is_master=control_saving)
|
||||
|
||||
if control_saving:
|
||||
assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0."
|
||||
index_file.append_meta_data("total_size", total_size)
|
||||
index_file.write_index_file(save_index_file)
|
||||
else:
|
||||
return
|
||||
|
||||
dist.barrier(self.pp_group)
|
||||
|
||||
# The global master rank integrates the index files and cleans the folder.
|
||||
if self.pp_rank == 0:
|
||||
|
||||
final_index_file = CheckpointIndexFile(checkpoint)
|
||||
final_index_file.append_meta_data("total_size", 0)
|
||||
|
||||
for filename in os.listdir(tmp_index_file_folder):
|
||||
stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename))
|
||||
final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"]
|
||||
for param_id, state_filename in stage_index_file.weight_map.items():
|
||||
final_index_file.append_weight_map(param_id, state_filename)
|
||||
|
||||
# Store param groups.
|
||||
final_index_file.append_meta_data("param_groups", param_group_file)
|
||||
group_file_path = os.path.join(checkpoint, param_group_file)
|
||||
save_param_groups(optimizer.param_info, group_file_path)
|
||||
|
||||
final_index_file.write_index_file(final_index_file_path)
|
||||
rmtree(tmp_index_file_folder)
|
||||
|
||||
if self.verbose:
|
||||
logging.info(f"The model is split into checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {final_index_file_path}.")
|
||||
|
||||
def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""):
|
||||
"""
|
||||
Load sharded optimizer with the given path to index file of checkpoint folder.
|
||||
|
||||
Args:
|
||||
optimizer (OptimizerWrapper): The optimizer to be loaded.
|
||||
checkpoint_index_file (str): Path to the index file of checkpointing folder.
|
||||
prefix (str): Not used.
|
||||
"""
|
||||
|
||||
def _get_param_id_from_optimizer_param(param: torch.Tensor,
|
||||
master_to_working_map: Optional[Dict[int, torch.Tensor]] = None):
|
||||
if master_to_working_map is not None:
|
||||
working_param = master_to_working_map[id(param)]
|
||||
else:
|
||||
working_param = param
|
||||
return optimizer.param_info['param2id'][id(working_param)]
|
||||
|
||||
# id_map is a mapping from param ids kept by current pipeline, to their corresponding parameter objects.
|
||||
# When Zero is used, the mapped parameter objects should be fp32 master parameters.
|
||||
# IDs should be obtained through the param2id mapping saved earlier in optimizer.param_info.
|
||||
id_map = {}
|
||||
for pg in optimizer.optim.param_groups:
|
||||
for param in pg['params']:
|
||||
param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map)
|
||||
id_map[param_id] = param
|
||||
|
||||
# Read checkpoint index file.
|
||||
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
|
||||
ckpt_root_path = ckpt_index_file.root_path
|
||||
weight_map = ckpt_index_file.weight_map
|
||||
weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int
|
||||
|
||||
# Load param_groups
|
||||
param_group_path = ckpt_index_file.get_param_group_filename()
|
||||
if param_group_path is None:
|
||||
raise RuntimeError(f'Invalid index file path {checkpoint_index_file} for an optimizer. \
|
||||
Lacking param group file under current directory.')
|
||||
saved_groups = torch.load(param_group_path)
|
||||
|
||||
updated_groups = []
|
||||
for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
|
||||
# obtain updated param group
|
||||
new_pg = copy.deepcopy(saved_pg)
|
||||
new_pg['params'] = old_pg['params'] # The parameters in the same group shouldn't change.
|
||||
updated_groups.append(new_pg)
|
||||
optimizer.optim.__dict__.update({'param_groups': updated_groups})
|
||||
|
||||
# Load saved states to optimizer.
|
||||
# Keep a record of loaded files so that file will not be repeatedly loaded.
|
||||
loaded_file = set()
|
||||
for pg in optimizer.optim.param_groups:
|
||||
for param in pg['params']:
|
||||
if param is None:
|
||||
continue
|
||||
param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map)
|
||||
if param_id not in weight_map:
|
||||
continue
|
||||
filename = weight_map[param_id]
|
||||
|
||||
# If the file containing this param's states has already been loaded, skip it.
|
||||
if filename in loaded_file:
|
||||
continue
|
||||
|
||||
file_path = os.path.join(ckpt_root_path, filename)
|
||||
state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
|
||||
load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
|
||||
loaded_file.add(filename)
|
||||
|
||||
# Then shard the loaded optimizer states if using tp/zero.
|
||||
for param, state in optimizer.optim.state.items():
|
||||
device = param.device
|
||||
if self.master_to_working_map is not None:
|
||||
working_param = self.master_to_working_map[id(param)]
|
||||
else:
|
||||
working_param = param
|
||||
original_shape = optimizer.param_info['param2shape'][id(working_param)]
|
||||
sharded_state = self.shard_from_complete_optimizer_state(state,
|
||||
current_shape=working_param.shape,
|
||||
original_shape=original_shape,
|
||||
device=device,
|
||||
inplace=True)
|
||||
optimizer.optim.state[param] = sharded_state
|
||||
|
||||
sharded_optimizer_loading_epilogue(optimizer.optim)
|
||||
if self.verbose:
|
||||
logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
|
||||
|
||||
def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
|
||||
# TODO(Baizhou): support this feature after implementing complete state_dict collection
|
||||
|
@ -314,3 +577,121 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
"""
|
||||
if self.coordinator.is_master():
|
||||
super().save_lr_scheduler(lr_scheduler, checkpoint)
|
||||
|
||||
def link_master_and_working_param(self, working_to_master_map: Dict[Union[int, torch.Tensor], torch.Tensor],
|
||||
master_to_working_map: Dict[Union[int, torch.Tensor], torch.Tensor]):
|
||||
"""
|
||||
Create mappings between working params (for forward/backward) and master params (for optimizer update) with passed in mappings.
|
||||
This mapping can only be created when mixed precision is used.
|
||||
The created mappings should be mappings from integer parameter addresses to parameter objects.
|
||||
|
||||
Args:
|
||||
working_to_master_map (Dict[Union[int, torch.Tensor], torch.Tensor]): A mapping from working parameters objects/addresses to master parameter objects.
|
||||
master_to_working_map (Dict[Union[int, torch.Tensor], torch.Tensor]): A mapping from master parameters objects/addresses to working parameter objects.
|
||||
"""
|
||||
self.working_to_master_map = dict()
|
||||
for k, v in working_to_master_map.items():
|
||||
if isinstance(k, torch.Tensor):
|
||||
self.working_to_master_map[id(k)] = v
|
||||
elif isinstance(k, int):
|
||||
self.working_to_master_map[k] = v
|
||||
else:
|
||||
raise ValueError(
|
||||
f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!")
|
||||
|
||||
self.master_to_working_map = dict()
|
||||
for k, v in master_to_working_map.items():
|
||||
if isinstance(k, torch.Tensor):
|
||||
self.master_to_working_map[id(k)] = v
|
||||
elif isinstance(k, int):
|
||||
self.master_to_working_map[k] = v
|
||||
else:
|
||||
raise ValueError(
|
||||
f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!")
|
||||
|
||||
@staticmethod
|
||||
def gather_from_sharded_optimizer_state(state: OrderedDict, param: torch.Tensor, original_shape: torch.Size,
|
||||
dp_group: ProcessGroup, tp_group: ProcessGroup, use_zero: bool,
|
||||
inplace: bool) -> OrderedDict:
|
||||
"""
|
||||
With given parameter and its optimizer states, gather the complete optimizer state for saving.
|
||||
|
||||
Args:
|
||||
state (OrderedDict): Optimizer states of given parameter, might be distributed among tp/dp group if using TP/Zero.
|
||||
param (torch.Tensor): The given parameter. It should be working_param when using Zero.
|
||||
original_shape (torch.Size): The size of parameter before sharding.
|
||||
dp_group (ProcessGroup): The process group of data parallel.
|
||||
tp_group (ProcessGroup): The process group of tensor parallel.
|
||||
use_zero (bool): Whether Zero is used.
|
||||
inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.
|
||||
|
||||
Returns:
|
||||
OrderedDict: The complete optimizer state of given parameter.
|
||||
"""
|
||||
dp_size = dist.get_world_size(dp_group)
|
||||
tp_size = dist.get_world_size(tp_group)
|
||||
current_shape = param.shape
|
||||
state_ = state if inplace else copy.deepcopy(state)
|
||||
|
||||
for k, v in state_.items():
|
||||
if isinstance(v, torch.Tensor) and k != 'step':
|
||||
|
||||
# First gather Zero shards.
|
||||
if use_zero:
|
||||
v = v.cuda()
|
||||
gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)]
|
||||
dist.all_gather(gather_tensor, v, group=dp_group)
|
||||
v = torch.stack(gather_tensor).view(-1)[:param.numel()].reshape_as(param)
|
||||
|
||||
# Then gather TP shards.
|
||||
partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size)
|
||||
if partition_dim is not None:
|
||||
gather_tensor = [torch.zeros_like(v) for _ in range(tp_size)]
|
||||
dist.all_gather(gather_tensor, v, group=tp_group)
|
||||
v = torch.cat(gather_tensor, dim=partition_dim)
|
||||
|
||||
state_[k] = v.detach().clone().cpu()
|
||||
|
||||
return state_
|
||||
|
||||
def shard_from_complete_optimizer_state(self, state: OrderedDict, current_shape: torch.Size,
|
||||
original_shape: torch.Size, device: torch.device,
|
||||
inplace: bool) -> OrderedDict:
|
||||
"""
|
||||
With complete optimizer states of a specific parameter loaded from checkpoint,
|
||||
slice out the sharded optimizer states kept by current device.
|
||||
|
||||
Args:
|
||||
state (OrderedDict): Complete optimizer states of a given parameter, loaded from checkpoint.
|
||||
current_shape (torch.Size): The size of parameter after sharding.
|
||||
original_shape (torch.Size): The size of parameter before sharding.
|
||||
device (torch.device): The destination device of loaded optimizer states.
|
||||
inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.
|
||||
|
||||
Returns:
|
||||
OrderedDict: The sharded optimizer state of the given parameter.
|
||||
"""
|
||||
state_ = state if inplace else copy.deepcopy(state)
|
||||
|
||||
for k, v in state_.items():
|
||||
if isinstance(v, torch.Tensor) and k != 'step':
|
||||
|
||||
# Shard state along tensor parallel group.
|
||||
partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size)
|
||||
if partition_dim is not None:
|
||||
slice_size = current_shape[partition_dim]
|
||||
v = v.split(slice_size, dim=partition_dim)[self.tp_rank]
|
||||
|
||||
# Shard state along data parallel group when using Zero.
|
||||
if self.use_zero:
|
||||
padding_size = (self.dp_size - v.numel() % self.dp_size) % self.dp_size
|
||||
with torch.no_grad():
|
||||
v = v.flatten()
|
||||
if padding_size > 0:
|
||||
v = torch.nn.functional.pad(v, [0, padding_size])
|
||||
slice_size = v.numel() // self.dp_size
|
||||
v = v.split(slice_size, dim=0)[self.dp_rank]
|
||||
|
||||
state_[k] = v.detach().clone().to(device)
|
||||
|
||||
return state_
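A hypothetical walk-through of the slicing above with the distributed calls removed: a complete 'exp_avg' of shape (8, 4), tp_size = 2 sharding dim 0, then a ZeRO flat shard with dp_size = 4:

import torch

full = torch.arange(32.0).reshape(8, 4)            # complete state loaded from checkpoint
current_shape, original_shape = torch.Size([4, 4]), torch.Size([8, 4])
tp_rank, tp_size, dp_rank, dp_size = 1, 2, 0, 4

v = full.split(current_shape[0], dim=0)[tp_rank]   # TP slice along the partition dim -> (4, 4)
v = v.flatten()                                    # ZeRO keeps a flat (padded) dp shard
padding_size = (dp_size - v.numel() % dp_size) % dp_size   # 0 in this example
v = v.split(v.numel() // dp_size, dim=0)[dp_rank]
assert v.tolist() == [16.0, 17.0, 18.0, 19.0]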
|
||||
@ -1,4 +1,5 @@
|
|||
# coding=utf-8
|
||||
import copy
|
||||
import os
|
||||
import re
|
||||
from collections import abc as container_abcs
|
||||
|
@ -8,7 +9,9 @@ from pathlib import Path
|
|||
from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from colossalai.interface import OptimizerWrapper
|
||||
|
@ -93,24 +96,31 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
|
|||
return False
|
||||
|
||||
|
||||
def gather_distributed_param(param: torch.Tensor, keep_vars: bool = False):
|
||||
def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Size, tp_size: int) -> Optional[int]:
|
||||
"""
|
||||
Gather the complete parameter for saving if passed in param is distributed.
|
||||
Given the current shape of parameter and the shape of parameter before sharding,
|
||||
return the dimension along which the parameter is sharded when using tensor parallel.
|
||||
If tensor parallel is not used, return None.
|
||||
|
||||
Args:
|
||||
param (torch.Tensor): A model parameter, might be d_tensor.
|
||||
keep_vars (bool, optional): Whether to return the parameter in calculation graph. Defaults to False.
|
||||
current_shape (torch.Size): The current shape of parameter after sharding.
|
||||
original_shape (torch.Size): The shape of parameter before sharding.
|
||||
tp_size (int): The size of tp group.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: the complete parameter
|
||||
Optional[int]: The dimension along which parameter is partitioned.
|
||||
"""
|
||||
param_ = param if keep_vars else param.detach()
|
||||
if is_distributed_tensor(param_):
|
||||
return to_global(param_)
|
||||
elif is_customized_distributed_tensor(param_):
|
||||
return to_global_for_customized_distributed_tensor(param_)
|
||||
else:
|
||||
return param_
|
||||
partition_dim = None
|
||||
for dim, length in enumerate(original_shape):
|
||||
if length > current_shape[dim]:
|
||||
partition_dim = dim
|
||||
break
|
||||
if partition_dim is not None:
|
||||
assert original_shape[partition_dim] == tp_size * current_shape[partition_dim], \
|
||||
f"The parameter isn't evenly distributed among tensor parallel group: \
|
||||
shape before sharding {original_shape}, shape after sharding {current_shape}"
|
||||
|
||||
return partition_dim
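A couple of hypothetical shape examples for the helper above:

import torch

# a (4096, 4096) weight column-sharded across tp_size=4 -> each rank holds (4096, 1024)
assert search_tp_partition_dim(torch.Size([4096, 1024]), torch.Size([4096, 4096]), tp_size=4) == 1
# identical shapes mean the parameter is not tensor-parallel sharded
assert search_tp_partition_dim(torch.Size([4096]), torch.Size([4096]), tp_size=4) is None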
|
||||
|
||||
|
||||
# ======================================
|
||||
|
@ -136,7 +146,8 @@ class StateDictSharder:
|
|||
self.current_block = OrderedDict()
|
||||
self.current_block_size = 0
|
||||
|
||||
def append(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]:
|
||||
def append_param(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]:
|
||||
|
||||
tensor_size = calculate_tensor_size(tensor)
|
||||
ret_block = None
|
||||
ret_block_size = 0
|
||||
|
@ -153,6 +164,64 @@ class StateDictSharder:
|
|||
self.current_block_size += tensor_size
|
||||
return ret_block, ret_block_size
|
||||
|
||||
def append_optim_state(self, param_id: int, state: OrderedDict) -> Tuple[Optional[OrderedDict], int]:
|
||||
|
||||
# A state might contain more than one tensors.
|
||||
# e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq'
|
||||
state_size = 0
|
||||
isDTensor = False
|
||||
for state_tensor in state.values():
|
||||
|
||||
# When state_tensor is not of Tensor class,
|
||||
# e.g., an SGD optimizer with momentum set to 0 can have None as its state
|
||||
# The calculation of tensor size should be skipped to avoid error.
|
||||
if not isinstance(state_tensor, torch.Tensor):
|
||||
continue
|
||||
|
||||
# If the states are stored as DTensors, mark isDTensor as true.
|
||||
if is_distributed_tensor(state_tensor):
|
||||
isDTensor = True
|
||||
state_size += calculate_tensor_size(state_tensor)
|
||||
|
||||
ret_block = None
|
||||
ret_block_size = 0
|
||||
|
||||
# directly return if state is stored as distributed tensor
|
||||
if isDTensor:
|
||||
return ret_block, ret_block_size
|
||||
|
||||
# before we return the current block and create a new block,
|
||||
# we need to ensure that the current block is not empty
|
||||
if self.current_block_size + state_size > self.max_shard_size and self.current_block_size > 0:
|
||||
ret_block = self.current_block
|
||||
ret_block_size = self.current_block_size
|
||||
self.current_block = OrderedDict()
|
||||
self.current_block_size = 0
|
||||
|
||||
self.current_block[param_id] = state
|
||||
self.current_block_size += state_size
|
||||
return ret_block, ret_block_size
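A rough usage sketch of the flush behaviour (assuming calculate_tensor_size reports sizes in MB, as the 1024 default shard size suggests): states accumulate in the current block until adding another one would exceed max_shard_size, at which point the filled block is handed back:

import torch

sharder = StateDictSharder(1)                                       # ~1 MB threshold for the sketch
state = {'step': torch.tensor(1), 'exp_avg': torch.zeros(300_000)}  # ~1.14 MB of fp32
block, size = sharder.append_optim_state(0, state)
assert block is None                       # first state is kept, nothing to flush yet
block, size = sharder.append_optim_state(1, state)
assert list(block.keys()) == [0]           # adding the second state flushes param 0's block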
|
||||
|
||||
|
||||
def gather_distributed_param(param: torch.Tensor, keep_vars: bool = False) -> torch.Tensor:
|
||||
"""
|
||||
Gather the complete parameter for saving if passed in param is distributed under tp setting.
|
||||
|
||||
Args:
|
||||
param (torch.Tensor): A model parameter, might be d_tensor.
|
||||
keep_vars (bool, optional): Whether to return the parameter in calculation graph. Defaults to False.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: the complete parameter
|
||||
"""
|
||||
param_ = param if keep_vars else param.detach()
|
||||
if is_distributed_tensor(param_):
|
||||
return to_global(param_)
|
||||
elif is_customized_distributed_tensor(param_):
|
||||
return to_global_for_customized_distributed_tensor(param_)
|
||||
else:
|
||||
return param_
|
||||
|
||||
|
||||
def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
|
||||
checkpoint: str,
|
||||
|
@ -198,28 +267,17 @@ def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024)
|
|||
Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
|
||||
given size.
|
||||
"""
|
||||
current_block = {}
|
||||
current_block_size = 0
|
||||
state_dict_sharder = StateDictSharder(max_shard_size)
|
||||
|
||||
for key, weight in state_dict.items():
|
||||
ret_block = None
|
||||
ret_block_size = 0
|
||||
if not is_distributed_tensor(weight):
|
||||
weight_size = calculate_tensor_size(weight)
|
||||
block, block_size = state_dict_sharder.append_param(key, weight)
|
||||
|
||||
# If this weight is going to tip up over the maximal size, we split.
|
||||
if current_block_size + weight_size > max_shard_size and current_block_size > 0:
|
||||
ret_block = current_block
|
||||
ret_block_size = current_block_size
|
||||
current_block = {}
|
||||
current_block_size = 0
|
||||
current_block[key] = weight
|
||||
current_block_size += weight_size
|
||||
if block is not None:
|
||||
yield block, block_size
|
||||
|
||||
if ret_block != None:
|
||||
yield ret_block, ret_block_size
|
||||
|
||||
yield current_block, current_block_size
|
||||
# Return the last block in sharder.
|
||||
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
|
||||
|
||||
|
||||
def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
|
||||
|
@ -230,212 +288,15 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
|
|||
|
||||
# Only split state_dict['state']; state_dict['param_group'] is not considered in this function.
|
||||
states = state_dict['state']
|
||||
|
||||
current_block = {}
|
||||
current_block_size = 0
|
||||
state_dict_sharder = StateDictSharder(max_shard_size)
|
||||
|
||||
for param_id, state in states.items():
|
||||
block, block_size = state_dict_sharder.append_optim_state(param_id, state)
|
||||
if block is not None:
|
||||
yield block, block_size
|
||||
|
||||
ret_block = None
|
||||
ret_block_size = 0
|
||||
|
||||
# A state might contain more than one tensors.
|
||||
# e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq'
|
||||
state_size = 0
|
||||
isDTensor = False
|
||||
for state_tensor in state.values():
|
||||
|
||||
# When state_tensor is not of Tensor class,
|
||||
# e.g., a SGD optimizer with momentum set to 0 can have None as state
|
||||
# The calculation of tensor size should be skipped to avoid error.
|
||||
if not isinstance(state_tensor, torch.Tensor):
|
||||
continue
|
||||
|
||||
# If the states are stored as DTensors, mark isDTensor as true.
|
||||
if is_distributed_tensor(state_tensor):
|
||||
isDTensor = True
|
||||
state_size += calculate_tensor_size(state_tensor)
|
||||
|
||||
if not isDTensor:
|
||||
|
||||
if current_block_size + state_size > max_shard_size and current_block_size > 0:
|
||||
ret_block = current_block
|
||||
ret_block_size = current_block_size
|
||||
current_block = {}
|
||||
current_block_size = 0
|
||||
|
||||
current_block[param_id] = state
|
||||
current_block_size += state_size
|
||||
|
||||
if ret_block != None:
|
||||
yield ret_block, ret_block_size
|
||||
|
||||
yield current_block, current_block_size
|
||||
|
||||
|
||||
def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
|
||||
"""
|
||||
load shard state dict into model
|
||||
"""
|
||||
if use_safetensors and not checkpoint_file.suffix == ".safetensors":
|
||||
raise Exception("load the model using `safetensors`, but no file endwith .safetensors")
|
||||
if use_safetensors:
|
||||
from safetensors.torch import load_file as safe_load_file
|
||||
from safetensors.torch import safe_open
|
||||
with safe_open(checkpoint_file, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
if metadata["format"] != "pt":
|
||||
raise NotImplementedError(
|
||||
f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet.")
|
||||
return safe_load_file(checkpoint_file)
|
||||
else:
|
||||
return torch.load(checkpoint_file)
|
||||
|
||||
|
||||
def load_state_dict_into_model(model: nn.Module,
|
||||
state_dict: torch.Tensor,
|
||||
missing_keys: List,
|
||||
strict: bool = False,
|
||||
load_sub_module: bool = True):
|
||||
r"""Copies parameters and buffers from :attr:`state_dict` into
|
||||
this module and its descendants.
|
||||
|
||||
Args:
|
||||
state_dict (dict): a dict containing parameters and
|
||||
persistent buffers.
|
||||
"""
|
||||
if not isinstance(state_dict, Mapping):
|
||||
raise TypeError("Expected state_dict to be dict-like, got {}.".format(type(state_dict)))
|
||||
|
||||
unexpected_keys: List[str] = []
|
||||
sub_missing_keys: List[str] = []
|
||||
error_msgs: List[str] = []
|
||||
|
||||
# copy state_dict so _load_from_state_dict can modify it
|
||||
metadata = getattr(state_dict, '_metadata', None)
|
||||
state_dict = OrderedDict(state_dict)
|
||||
if metadata is not None:
|
||||
state_dict._metadata = metadata
|
||||
|
||||
def load(module: nn.Module, state_dict, prefix="", load_sub_module: bool = True):
|
||||
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
||||
args = (state_dict, prefix, local_metadata, True, sub_missing_keys, [], error_msgs)
|
||||
# Parameters of module and children will start with prefix. We can exit early if there are none in this
|
||||
# state_dict
|
||||
if len([key for key in state_dict if key.startswith(prefix)]) > 0:
|
||||
module._load_from_state_dict(*args)
|
||||
if load_sub_module:
|
||||
for name, child in module._modules.items():
|
||||
if child is not None:
|
||||
load(child, state_dict, prefix + name + ".")
|
||||
|
||||
load(model, state_dict, "", load_sub_module)
|
||||
del load
|
||||
|
||||
missing_keys = missing_keys.append(sub_missing_keys)
|
||||
|
||||
if strict:
|
||||
if len(unexpected_keys) > 0:
|
||||
error_msgs = 'Unexpected key(s) in state_dict: {}. '.format(', '.join(
|
||||
'"{}"'.format(k) for k in unexpected_keys))
|
||||
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
|
||||
model.__class__.__name__, "\n\t".join(error_msgs)))
|
||||
|
||||
|
||||
def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str) -> dict:
|
||||
"""
|
||||
Load information of param_groups into an initialized optimizer.
|
||||
"""
|
||||
|
||||
# Load list of param_groups from given file path.
|
||||
# The params in saved_groups are in the form of integer indices.
|
||||
saved_groups = torch.load(param_group_path)
|
||||
if not isinstance(saved_groups, List):
|
||||
raise ValueError(f'The param_groups saved at {param_group_path} is not of List type')
|
||||
|
||||
# The params in param_groups are in the form of pytorch tensors.
|
||||
# For more details, please view source code of Optimizer class in pytorch.
|
||||
param_groups = optimizer.param_groups
|
||||
|
||||
# Check the compatibility of saved_groups and param_groups.
|
||||
if len(param_groups) != len(saved_groups):
|
||||
raise ValueError("loaded state dict has a different number of original parameter groups")
|
||||
param_lens = (len(g['params']) for g in param_groups)
|
||||
saved_lens = (len(g['params']) for g in saved_groups)
|
||||
if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
|
||||
raise ValueError("loaded state dict contains a parameter group "
|
||||
"that doesn't match the size of optimizer's group")
|
||||
|
||||
# Creating mapping from id to parameters.
|
||||
id_map = {
|
||||
old_id: p for old_id, p in zip(chain.from_iterable((g['params'] for g in saved_groups
|
||||
)), chain.from_iterable((g['params'] for g in param_groups)))
|
||||
}
|
||||
|
||||
# Update parameter groups, setting their 'params' value.
|
||||
def update_group(group, new_group):
|
||||
new_group['params'] = group['params']
|
||||
return new_group
|
||||
|
||||
updated_groups = [update_group(g, ng) for g, ng in zip(param_groups, saved_groups)]
|
||||
|
||||
optimizer.__dict__.update({'param_groups': updated_groups})
|
||||
return id_map
|
||||
|
||||
|
||||
def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: dict):
|
||||
r"""Copies states from `state_dict` into an Optimizer object.
|
||||
|
||||
Args:
|
||||
optimizer(Optimizer): An initialized Optimizer object to be loaded
|
||||
state_dict(dict): a mapping from tensor index (an integer)
|
||||
to its states to be loaded (a mapping from state name to a tensor).
|
||||
id_map(dict): a mapping from tensor index (an integer)
|
||||
to its corresponding parameter (a tensor) whose states will be updated.
|
||||
"""
|
||||
|
||||
def cast(param, value, key=None):
|
||||
r"""Make a deep copy of value, casting all tensors to device of param."""
|
||||
if isinstance(value, torch.Tensor):
|
||||
# Floating-point types are a bit special here. They are the only ones
|
||||
# that are assumed to always match the type of params.
|
||||
# Make sure state['step'] is not casted https://github.com/pytorch/pytorch/issues/74424
|
||||
if (key != "step"):
|
||||
if param.is_floating_point():
|
||||
value = value.to(param.dtype)
|
||||
value = value.to(param.device)
|
||||
return value
|
||||
elif isinstance(value, dict):
|
||||
return {k: cast(param, v, key=k) for k, v in value.items()}
|
||||
elif isinstance(value, container_abcs.Iterable):
|
||||
return type(value)(cast(param, v) for v in value)
|
||||
else:
|
||||
return value
|
||||
|
||||
# Copy state assigned to params (and cast tensors to appropriate types).
|
||||
# State that is not assigned to params is copied as is (needed for
|
||||
# backward compatibility).
|
||||
new_states = defaultdict(dict)
|
||||
for k, v in state_dict.items():
|
||||
if k in id_map:
|
||||
param = id_map[k]
|
||||
new_states[param] = cast(param, v)
|
||||
else:
|
||||
new_states[k] = v
|
||||
|
||||
optimizer.state.update(new_states)
|
||||
|
||||
|
||||
def sharded_optimizer_loading_epilogue(optimizer: Optimizer):
|
||||
r"""Do the cleaning up work after state_dict has been loaded into optimizer
|
||||
|
||||
Args:
|
||||
optimizer(Optimizer): An optimizer object whose state has just been loaded.
|
||||
"""
|
||||
|
||||
# Do the cleaning up as in src code of Pytorch.
|
||||
optimizer._hook_for_profile() # To support multiprocessing pickle/unpickle.
|
||||
optimizer.defaults.setdefault('differentiable', False)
|
||||
# Return the last block in sharder.
|
||||
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
|
||||
|
||||
|
||||
# ======================================
|
||||
|
@ -565,38 +426,180 @@ def generate_dtensor_file_name(param_name: str, index: int, use_safetensors: boo
|
|||
return f'{param_name}.{index}.{suffix}'
|
||||
|
||||
|
||||
def save_state_dict_as_shard(
|
||||
state_dict: dict,
|
||||
checkpoint_path: str,
|
||||
index: int,
|
||||
total_number: int,
|
||||
use_safetensors: bool,
|
||||
prefix: str = None,
|
||||
) -> None:
|
||||
"""
|
||||
Save state dict as shard.
|
||||
|
||||
Args:
|
||||
state_dict (dict): state dict.
|
||||
checkpoint_path (str): path to the checkpoint file.
|
||||
index (int): index of the shard.
|
||||
total_number (int): total number of shards.
|
||||
prefix (str): prefix of the shard file name.
|
||||
use_safetensors (bool): whether to use safetensors to save the checkpoint.
|
||||
"""
|
||||
# generate the shard name
|
||||
shard_file_name = generate_checkpoint_shard_file_name(index, total_number, use_safetensors, prefix)
|
||||
shard_file_path = Path(checkpoint_path).joinpath(shard_file_name).absolute()
|
||||
|
||||
# save the shard
|
||||
save_state_dict(state_dict, str(shard_file_path), use_safetensors)
|
||||
|
||||
|
||||
# ========================================
|
||||
# Helper functions for loading state dict
|
||||
# ========================================
|
||||
|
||||
|
||||
def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
|
||||
"""
|
||||
load shard state dict into model
|
||||
"""
|
||||
if use_safetensors and not checkpoint_file.suffix == ".safetensors":
|
||||
raise Exception("load the model using `safetensors`, but no file endwith .safetensors")
|
||||
if use_safetensors:
|
||||
from safetensors.torch import load_file as safe_load_file
|
||||
from safetensors.torch import safe_open
|
||||
with safe_open(checkpoint_file, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
if metadata["format"] != "pt":
|
||||
raise NotImplementedError(
|
||||
f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet.")
|
||||
return safe_load_file(checkpoint_file)
|
||||
else:
|
||||
return torch.load(checkpoint_file)
|
||||
|
||||
|
||||
def load_state_dict_into_model(model: nn.Module,
|
||||
state_dict: torch.Tensor,
|
||||
missing_keys: List,
|
||||
strict: bool = False,
|
||||
load_sub_module: bool = True):
|
||||
r"""Copies parameters and buffers from :attr:`state_dict` into
|
||||
this module and its descendants.
|
||||
|
||||
Args:
|
||||
state_dict (dict): a dict containing parameters and
|
||||
persistent buffers.
|
||||
"""
|
||||
if not isinstance(state_dict, Mapping):
|
||||
raise TypeError("Expected state_dict to be dict-like, got {}.".format(type(state_dict)))
|
||||
|
||||
unexpected_keys: List[str] = []
|
||||
sub_missing_keys: List[str] = []
|
||||
error_msgs: List[str] = []
|
||||
|
||||
# copy state_dict so _load_from_state_dict can modify it
|
||||
metadata = getattr(state_dict, '_metadata', None)
|
||||
state_dict = OrderedDict(state_dict)
|
||||
if metadata is not None:
|
||||
state_dict._metadata = metadata
|
||||
|
||||
def load(module: nn.Module, state_dict, prefix="", load_sub_module: bool = True):
|
||||
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
||||
args = (state_dict, prefix, local_metadata, True, sub_missing_keys, [], error_msgs)
|
||||
# Parameters of module and children will start with prefix. We can exit early if there are none in this
|
||||
# state_dict
|
||||
if len([key for key in state_dict if key.startswith(prefix)]) > 0:
|
||||
module._load_from_state_dict(*args)
|
||||
if load_sub_module:
|
||||
for name, child in module._modules.items():
|
||||
if child is not None:
|
||||
load(child, state_dict, prefix + name + ".")
|
||||
|
||||
load(model, state_dict, "", load_sub_module)
|
||||
del load
|
||||
|
||||
missing_keys = missing_keys.append(sub_missing_keys)
|
||||
|
||||
if strict:
|
||||
if len(unexpected_keys) > 0:
|
||||
error_msgs = 'Unexpected key(s) in state_dict: {}. '.format(', '.join(
|
||||
'"{}"'.format(k) for k in unexpected_keys))
|
||||
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
|
||||
model.__class__.__name__, "\n\t".join(error_msgs)))
|
||||
|
||||
|
||||
def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str) -> dict:
|
||||
"""
|
||||
Load information of param_groups into an initialized optimizer.
|
||||
"""
|
||||
|
||||
# Load list of param_groups from given file path.
|
||||
# The params in saved_groups are in the form of integer indices.
|
||||
saved_groups = torch.load(param_group_path)
|
||||
if not isinstance(saved_groups, List):
|
||||
raise ValueError(f'The param_groups saved at {param_group_path} is not of List type')
|
||||
|
||||
# The params in param_groups are in the form of pytorch tensors.
|
||||
# For more details, please view source code of Optimizer class in pytorch.
|
||||
param_groups = optimizer.param_groups
|
||||
|
||||
# Check the compatibility of saved_groups and param_groups.
|
||||
if len(param_groups) != len(saved_groups):
|
||||
raise ValueError("loaded state dict has a different number of original parameter groups")
|
||||
param_lens = (len(g['params']) for g in param_groups)
|
||||
saved_lens = (len(g['params']) for g in saved_groups)
|
||||
if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
|
||||
raise ValueError("loaded state dict contains a parameter group "
|
||||
"that doesn't match the size of optimizer's group")
|
||||
|
||||
# Creating mapping from id to parameters.
|
||||
id_map = {
|
||||
old_id: p for old_id, p in zip(chain.from_iterable((g['params'] for g in saved_groups
|
||||
)), chain.from_iterable((g['params'] for g in param_groups)))
|
||||
}
|
||||
|
||||
# Update parameter groups, setting their 'params' value.
|
||||
def update_group(group, new_group):
|
||||
new_group['params'] = group['params']
|
||||
return new_group
|
||||
|
||||
updated_groups = [update_group(g, ng) for g, ng in zip(param_groups, saved_groups)]
|
||||
|
||||
optimizer.__dict__.update({'param_groups': updated_groups})
|
||||
return id_map
|
||||
|
||||
|
||||
def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: dict, strict: bool = False):
    r"""Copies states from `state_dict` into an Optimizer object.

    Args:
        optimizer(Optimizer): An initialized Optimizer object to be loaded
        state_dict(dict): A mapping from tensor index (an integer)
            to its states to be loaded (a mapping from state name to a tensor).
        id_map(dict): A mapping from tensor index (an integer)
            to its corresponding parameter (a tensor) whose states will be updated.
        strict(bool, optional): If set to True, only load the parameters whose ids are in id_map. Defaults to False.
    """

    # Ensure that the keys of state_dict are integers.
    state_dict = {int(k): v for k, v in state_dict.items()}

    def cast(param, value, key=None):
        r"""Make a deep copy of value, casting all tensors to device of param."""
        if isinstance(value, torch.Tensor):
            # Floating-point types are a bit special here. They are the only ones
            # that are assumed to always match the type of params.
            # Make sure state['step'] is not casted https://github.com/pytorch/pytorch/issues/74424
            if (key != "step"):
                if param.is_floating_point():
                    value = value.to(param.dtype)
                value = value.to(param.device)
            return value
        elif isinstance(value, dict):
            return {k: cast(param, v, key=k) for k, v in value.items()}
        elif isinstance(value, container_abcs.Iterable):
            return type(value)(cast(param, v) for v in value)
        else:
            return value

    # Copy state assigned to params (and cast tensors to appropriate types).
    # State that is not assigned to params is copied as is (needed for
    # backward compatibility).
    new_states = defaultdict(dict)
    for k, v in state_dict.items():
        if k in id_map:
            param = id_map[k]
            new_states[param] = cast(param, v)
        elif not strict:
            new_states[k] = v

    optimizer.state.update(new_states)


def sharded_optimizer_loading_epilogue(optimizer: Optimizer):
    r"""Do the cleanup work after a state_dict has been loaded into the optimizer.

    Args:
        optimizer(Optimizer): An optimizer object whose state has just been loaded.
    """

    # Do the cleanup as in the source code of PyTorch.
    optimizer._hook_for_profile()    # To support multiprocessing pickle/unpickle.
    optimizer.defaults.setdefault('differentiable', False)
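
Taken together, the three helpers above form the greedy, shard-by-shard optimizer loading path. The following sketch only illustrates the intended call order for an already initialized optimizer; the shard file names, the use of torch.load, and the per-shard layout are assumptions, and the helpers defined above are presumed to be in scope.

# --- illustrative loading sketch (not part of the patch) ---
import torch

# Restore param_groups first and obtain the index -> live parameter mapping.
id_map = load_param_groups_into_optimizer(optimizer, 'pytorch_optim_group.bin')    # hypothetical path

# Each shard is assumed to map integer param indices to per-parameter states,
# e.g. {3: {'step': ..., 'exp_avg': ..., 'exp_avg_sq': ...}} for Adam.
for shard_path in ['optim-00001.bin', 'optim-00002.bin']:    # hypothetical shard files
    state_shard = torch.load(shard_path)
    load_states_into_optimizer(optimizer, state_shard, id_map)

# Final cleanup mirrors torch.optim.Optimizer.load_state_dict.
sharded_optimizer_loading_epilogue(optimizer)
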
def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]:
    """
    Check whether the checkpoint has an index file.
@ -679,7 +679,7 @@ class ZeroDDP(ColoDDP):
                        gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0, dtype))
                    gathered_param = gathered_param_buffer.pop(fp32_param)

                block, block_size = sharder.append(prefix + name, gathered_param)
                block, block_size = sharder.append_param(prefix + name, gathered_param)
                if block is not None:
                    yield block, block_size

@ -690,7 +690,7 @@ class ZeroDDP(ColoDDP):
        for name, buf in self.named_buffers():
            if buf is not None and name not in self._non_persistent_buffers_set:
                buffer = buf if keep_vars else buf.detach()
                block, block_size = sharder.append(prefix + name, buffer)
                block, block_size = sharder.append_param(prefix + name, buffer)
                if block is not None:
                    yield block, block_size
        # save extra states

@ -698,7 +698,7 @@ class ZeroDDP(ColoDDP):
        if getattr(self.__class__, "get_extra_state",
                   torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
            extra_state = self.get_extra_state()
            block, block_size = sharder.append(extra_state_key, extra_state)
            block, block_size = sharder.append_param(extra_state_key, extra_state)
            if block is not None:
                yield block, block_size
@ -10,7 +10,7 @@ from torch.nn import Parameter
from torch.optim import Optimizer

from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
from colossalai.checkpoint_io.utils import calculate_tensor_size
from colossalai.checkpoint_io.utils import StateDictSharder
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam
from colossalai.tensor.d_tensor import is_distributed_tensor
@ -691,49 +691,17 @@ class ZeroOptimizer(ColossalaiOptimizer):
            Iterator[OrderedDict]: A generator of state dict shard of optimizer states.
        """

        current_block = {}
        current_block_size = 0

        sharder = StateDictSharder(max_shard_size)
        for param_id in self.id_to_real_params.keys():

            dist.barrier()
            state = self.collect_states(param_id=param_id, only_rank_0=only_rank_0)

            ret_block = None
            ret_block_size = 0
            block, block_size = sharder.append_optim_state(param_id, state)
            if block is not None:
                yield block, block_size

            # A state might contain more than one tensors.
            # e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq'
            state_size = 0
            isDTensor = False
            for state_tensor in state.values():

                # When state_tensor is not of Tensor class,
                # e.g., a SGD optimizer with momentum set to 0 can have None as state
                # The calculation of tensor size should be skipped to avoid error.
                if not isinstance(state_tensor, torch.Tensor):
                    continue

                # If the states are stored as DTensors, mark isDTensor as true.
                if is_distributed_tensor(state_tensor):
                    isDTensor = True
                state_size += calculate_tensor_size(state_tensor)

            if not isDTensor:

                if current_block_size + state_size > max_shard_size and current_block_size > 0:
                    ret_block = current_block
                    ret_block_size = current_block_size
                    current_block = {}
                    current_block_size = 0

                current_block[param_id] = state
                current_block_size += state_size

            if ret_block != None:
                yield ret_block, ret_block_size

        yield current_block, current_block_size
        yield sharder.current_block, sharder.current_block_size


class GeminiAdamOptimizer(ZeroOptimizer):
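
The hunk above replaces the hand-rolled block bookkeeping with StateDictSharder. Judging only from the calls visible in this diff (append_param, append_optim_state, current_block, current_block_size), the accumulate-and-yield pattern looks roughly as sketched below; the helper function, tensor names and shard size are made up for illustration, and the sharder's internal behavior is an assumption, not shown here.

# --- illustrative sharding sketch (not part of the patch) ---
import torch
from colossalai.checkpoint_io.utils import StateDictSharder

def shard_named_tensors(named_tensors, max_shard_size):
    sharder = StateDictSharder(max_shard_size)
    for name, tensor in named_tensors:
        # append_param is expected to hand back a finished (block, block_size) pair
        # once the running block would outgrow max_shard_size; otherwise block is None.
        block, block_size = sharder.append_param(name, tensor)
        if block is not None:
            yield block, block_size
    # Whatever is still buffered becomes the final shard.
    yield sharder.current_block, sharder.current_block_size
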
@ -10,6 +10,7 @@ from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import (
    assert_close_loose,
    check_state_dict_equal,
    clear_cache_before_run,
    parameterize,
@ -19,34 +20,34 @@ from colossalai.testing import (
from tests.kit.model_zoo import model_zoo


# TODO (Baizhou): Add test cases for shard=False
@clear_cache_before_run()
@parameterize('shard', [True])
@parameterize('model_name', ['transformers_gpt'])
@parameterize('size_per_shard', [32])
@parameterize('test_config', [{
    'tp_size': 2,
    'pp_size': 2,
    'num_microbatches': 4,
    'precision': 'fp32',
}, {
    'tp_size': 1,
    'pp_size': 2,
    'num_microbatches': 4,
    'precision': 'fp32',
}, {
    'tp_size': 4,
    'pp_size': 1,
    'precision': 'fp32',
}, {
    'tp_size': 2,
    'pp_size': 1,
    'precision': 'fp32',
    'pp_size': 2,
    'num_microbatches': 4,
    'precision': 'fp16',
    'initial_scale': 1
}, {
    'tp_size': 2,
    'pp_size': 1,
    'zero_stage': 2,
    'precision': 'fp16',
    'initial_scale': 1
}, {
    'tp_size': 1,
    'pp_size': 2,
    'num_microbatches': 4,
    'zero_stage': 1,
    'precision': 'fp16',
    'initial_scale': 1
}])
def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict):
@ -61,46 +62,91 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf
        loss = criterion(outputs)
        return loss

    def _preprocess_data(data):
        if booster.plugin.stage_manager is not None:
            for k, v in data.items():
                if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__:
                    new_shape = [1] * v.dim()
                    new_shape[0] = 4
                    data[k] = v.to('cuda').repeat(*new_shape)
            return iter([data])
        else:
            return {k: v.cuda() for k, v in data.items()}

    model = model_fn().cuda()
    optimizer = Adam(model.parameters(), lr=1e-3)
    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

    new_model = model_fn().cuda()
    new_optimizer = Adam(new_model.parameters(), lr=1e-3)
    new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)

    data = data_gen_fn()
    model.train()
    if booster.plugin.stage_manager is not None:
        for k, v in data.items():
            if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__:
                new_shape = [1] * v.dim()
                new_shape[0] = 4
                data[k] = v.to('cuda').repeat(*new_shape)
        data_iter = iter([data])
        output = booster.execute_pipeline(data_iter,
                                          model,
                                          _criterion,
                                          optimizer,
                                          return_loss=True,
                                          return_outputs=False)
        booster.execute_pipeline(_preprocess_data(data),
                                 model,
                                 _criterion,
                                 optimizer,
                                 return_loss=True,
                                 return_outputs=False)
    else:
        data = {k: v.cuda() for k, v in data.items()}
        output = model(**data)
        output = model(**_preprocess_data(data))
        loss = criterion(output)
        optimizer.backward(loss)

    optimizer.step()

    with shared_tempdir() as tempdir:

        model_ckpt_path = f"{tempdir}/model"
        # optimizer_ckpt_path = f"{tempdir}/optimizer"
        optimizer_ckpt_path = f"{tempdir}/optimizer"
        booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
        # booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
        booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
        dist.barrier()

        new_model = model_fn().cuda()
        new_optimizer = Adam(new_model.parameters(), lr=1e-3)
        new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)

        booster.load_model(new_model, model_ckpt_path)
        check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False)
        booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
        check_state_dict_equal(optimizer.unwrap().state_dict(), new_optimizer.unwrap().state_dict(), False)
        dist.barrier()

        # Check whether the loaded model & optimizer works smoothly.
        model.train()
        new_model.train()
        if booster.plugin.stage_manager is not None:
            booster.execute_pipeline(_preprocess_data(data),
                                     model,
                                     _criterion,
                                     optimizer,
                                     return_loss=True,
                                     return_outputs=False)
            booster.execute_pipeline(_preprocess_data(data),
                                     new_model,
                                     _criterion,
                                     new_optimizer,
                                     return_loss=True,
                                     return_outputs=False)
        else:
            old_model_loss = criterion(model(**_preprocess_data(data)))
            optimizer.backward(old_model_loss)
            new_model_loss = criterion(new_model(**_preprocess_data(data)))
            new_optimizer.backward(new_model_loss)

        optimizer.step()
        new_optimizer.step()

        # Check updated weights.
        stage_manager = booster.plugin.stage_manager

        if stage_manager is None or stage_manager.is_first_stage():
            assert_close_loose(model.unwrap().wte.weight.data, new_model.unwrap().wte.weight.data, atol=5e-3, rtol=5e-3)
            assert_close_loose(model.unwrap().h[0].mlp.c_fc.weight.data,
                               new_model.unwrap().h[0].mlp.c_fc.weight.data,
                               atol=5e-3,
                               rtol=5e-3)

    dist.barrier()
    Randomizer.reset_index()
    clear_layout_converter()