Mirror of https://github.com/hpcaitech/ColossalAI

Commit 58913441a1 (parent fee32a3b78, branch pull/4217/head)
[checkpointio] Unsharded Optimizer Checkpoint for Gemini Plugin (#4141)

* [checkpointio] unsharded optimizer checkpoint for Gemini plugin
* [checkpointio] unsharded optimizer checkpoint for Gemini using all_gather
@@ -33,44 +33,40 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         super().__init__()
         self.coordinator = DistCoordinator()
 
-    def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True):
-        """
-        Load model from checkpoint with automatic unwrapping.
-        """
-        # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
-        return super().load_unsharded_model(model, checkpoint, strict=strict)
-
     def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
         """
-        Save model to checkpoint but only on master process.
+        Save sharded model to checkpoint but only on master process.
+        The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
+        As there is communication when getting state dict, this must be called on all processes.
         """
-        # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
-        # as there is communication when get state dict, this must be called on all processes
         state_dict = model.state_dict(only_rank_0=True)
         if self.coordinator.is_master():
             save_state_dict(state_dict, checkpoint, use_safetensors)
 
+    def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True):
+        """
+        Load model from checkpoint with automatic unwrapping.
+        The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
+        """
+        super().load_unsharded_model(model, checkpoint, strict=strict)
+
     def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
         """
-        Save optimizer to checkpoint but only on master process.
-        """
-        # TODO(ver217): optimizer state dict is sharded
-        warnings.warn('GeminiPlugin does not support save full optimizer checkpoint now. Save it on every process.')
-        checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
-        super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
-
-    def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
-        warnings.warn(
-            'GeminiPlugin can only load optimizer checkpoint saved by itself with the same number of processes.')
-        checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
-        super().load_optimizer(optimizer, checkpoint)
-
-    def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
-        """
-        Save model to checkpoint but only on master process.
+        Save unsharded optimizer state dict to checkpoint.
+        After calling optimizer.state_dict(), the complete optimizer states will be collected on master rank.
+        As there is communication when getting state dict, this must be called on all processes.
+        The saving process will only be executed by master rank.
         """
+        state_dict = optimizer.state_dict()
         if self.coordinator.is_master():
-            super().save_lr_scheduler(lr_scheduler, checkpoint)
+            save_state_dict(state_dict, checkpoint, use_safetensors=False)
+
+    def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str):
+        """
+        Loading unsharded optimizer from checkpoint file.
+        For each process, only loading optimizer states of parameters it controls.
+        """
+        super().load_unsharded_optimizer(optimizer, checkpoint)
 
     def save_sharded_model(self,
                            model: GeminiDDP,
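Editor's note: a minimal usage sketch (not part of this commit) of how the unsharded optimizer path above is driven through the Booster API. The toy model, file paths, and the launch call are illustrative assumptions; the plugin, booster, save and load calls mirror the updated tests further down in this diff.

# Hedged sketch: assumes a CUDA-capable distributed launch has been set up
# (e.g. via colossalai.launch as in the tests below, or launch_from_torch under torchrun).
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch(config={})    # one way to initialize the process group

model = torch.nn.Linear(32, 32).cuda()     # illustrative toy model
optimizer = HybridAdam(model.parameters(), lr=1e-3)
booster = Booster(plugin=GeminiPlugin(precision="fp16", initial_scale=2**14))
model, optimizer, *_ = booster.boost(model, optimizer)

# ... forward / backward / step ...

# shard=False exercises the new unsharded paths: every rank takes part in the
# all_gather-based state collection, and only rank 0 writes the checkpoint file.
booster.save_model(model, "./ckpt/model.bin")
booster.save_optimizer(optimizer, "./ckpt/optimizer.bin", shard=False)

# On reload, each rank keeps only the optimizer states of parameters it controls.
booster.load_model(model, "./ckpt/model.bin")
booster.load_optimizer(optimizer, "./ckpt/optimizer.bin")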
@@ -82,6 +78,12 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         """
         Save sharded model
         """
+        if os.path.isfile(checkpoint_path):
+            logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
+            return
+
+        Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
+
         state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True, dtype=torch.float32)
         weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
         total_size = 0

@@ -117,6 +119,23 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         """
         return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)
 
+    def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str,
+                               size_per_shard: int):
+        """
+        Save sharded optimizer state dict to checkpoint folder.
+        As there is communication when getting state dict, this must be called on all processes.
+        """
+        Path(checkpoint).mkdir(parents=True, exist_ok=True)
+        super().save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
+
+    def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint_index_file: Path, prefix: str):
+        """
+        Loading sharded optimizer from checkpoint folder, with index file given.
+        For each process, only loading optimizer states of parameters it controls.
+        """
+        # TODO(Baizhou): To be implemented.
+        pass
+
 
 class GeminiModel(ModelWrapper):
@@ -193,7 +212,7 @@ class GeminiPlugin(DPPluginBase):
             which will be used when using hybrid CPU optimizer.
             This argument is meaningless when `placement_policy` of `GeminiManager` is not "auto".
             Defaults to 0.0.
-        initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32.
+        initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**16.
         min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1.
         growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2.
         backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5.

@@ -219,7 +238,7 @@ class GeminiPlugin(DPPluginBase):
                  min_chunk_size_m: float = 32,
                  memstats: Optional[MemStats] = None,
                  gpu_margin_mem_ratio: float = 0.0,
-                 initial_scale: float = 2**32,
+                 initial_scale: float = 2**16,
                  min_scale: float = 1,
                  growth_factor: float = 2,
                  backoff_factor: float = 0.5,
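Editor's note: a short sketch (not from the diff) of the new default versus an explicit override. 2**16 matches the common default of torch.cuda.amp.GradScaler; the updated tests in this commit pass initial_scale=2**14 explicitly.

from colossalai.booster.plugin import GeminiPlugin

# Relies on the defaults documented above: initial_scale now starts at 2**16.
plugin_default = GeminiPlugin(precision="fp16")
# Explicit override, as used by the updated tests in this commit.
plugin_small_scale = GeminiPlugin(precision="fp16", initial_scale=2**14)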
@@ -152,6 +152,7 @@ class CheckpointIO(ABC):
                 names to compose the keys in state_dict. Defaults to None.
+            size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
         """
 
         index_file_exists, index_file_path = has_index_file(checkpoint)
 
         if Path(checkpoint).is_dir() and not index_file_exists:

@@ -186,6 +187,7 @@ class CheckpointIO(ABC):
             prefix (str): prefix for the optimizer checkpoint when shard = True. Default: None.
+            size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True.
         """
 
         if shard:
             self.save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
         else:

@@ -28,6 +28,7 @@ from .utils import (
     shard_model_checkpoint,
     shard_optimizer_checkpoint,
     sharded_optimizer_loading_epilogue,
+    unwrap_optimizer,
 )
 
 __all__ = ['GeneralCheckpointIO']

@@ -59,7 +60,7 @@ class GeneralCheckpointIO(CheckpointIO):
 
         # If optimizer is wrapped, unwrap it.
         if isinstance(optimizer, OptimizerWrapper):
-            optimizer = optimizer.optim
+            optimizer = unwrap_optimizer(optimizer)
 
         # Read checkpoint index file.
         ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)

@@ -96,6 +97,11 @@ class GeneralCheckpointIO(CheckpointIO):
             - A group file (pytorch_optim_group.bin) recording information of param_groups
             - Multiple files (pytorch_optim-000XX.bin) that store state tensors of optimizer in a sharding way
         """
+
+        # If optimizer is wrapped, unwrap it.
+        if isinstance(optimizer, OptimizerWrapper):
+            optimizer = unwrap_optimizer(optimizer)
+
         if os.path.isfile(checkpoint):
             logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
             return

@@ -121,9 +127,8 @@ class GeneralCheckpointIO(CheckpointIO):
             shard, current_size = shard_pair
             shard_file = get_shard_filename(states_name, idx)
             total_size = total_size + current_size
+            for param_id in shard.keys():
+                index_file.append_weight_map(str(param_id), shard_file)
 
-            for key in shard.keys():
-                index_file.append_weight_map(key, shard_file)
-
             checkpoint_file_path = os.path.join(checkpoint, shard_file)
             save_state_dict(shard, checkpoint_file_path, use_safetensors=False)

@@ -177,7 +182,6 @@ class GeneralCheckpointIO(CheckpointIO):
             total_size = total_size + shard_pair[1]
             for key in shard.keys():
                 index_file.append_weight_map(key, shard_file)
-
             checkpoint_file_path = os.path.join(checkpoint_path, shard_file)
             save_state_dict(shard, checkpoint_file_path, use_safetensors)
 
@@ -10,6 +10,8 @@ import torch
 import torch.nn as nn
 from torch.optim import Optimizer
 
+from colossalai.interface import OptimizerWrapper
+from colossalai.nn.optimizer import ColossalaiOptimizer
 from colossalai.tensor.d_tensor import is_distributed_tensor
 
 SAFE_WEIGHTS_NAME = "model.safetensors"

@@ -88,6 +90,19 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
 # ======================================
 # Helper functions for saving shard file
 # ======================================
+def unwrap_optimizer(optimizer: OptimizerWrapper):
+    '''
+    Unwrap a wrapped optimizer.
+    This method should be used before saving/loading it to/from sharded checkpoints.
+    '''
+
+    # TODO(Baizhou): ColossalaiOptimizer will be replaced with OptimizerWrapper in the future
+    unwrapped_optim = optimizer.optim
+    if isinstance(unwrapped_optim, ColossalaiOptimizer):
+        unwrapped_optim = unwrapped_optim.optim
+    return unwrapped_optim
+
+
 def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
     """
     Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
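Editor's note: a small sketch (not part of the commit) of what the new helper is expected to return for the wrapper layering named in the TODO above; the toy parameter and SGD optimizer are illustrative assumptions.

import torch
from colossalai.interface import OptimizerWrapper
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.checkpoint_io.utils import unwrap_optimizer

# Toy parameter/optimizer just to show the unwrapping chain.
param = torch.nn.Parameter(torch.zeros(4))
raw_optim = torch.optim.SGD([param], lr=0.1)

# OptimizerWrapper around a legacy ColossalaiOptimizer around the raw torch optimizer.
wrapped = OptimizerWrapper(ColossalaiOptimizer(raw_optim))

# unwrap_optimizer peels both layers and hands back the raw torch optimizer.
assert unwrap_optimizer(wrapped) is raw_optim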
@@ -103,7 +118,7 @@ def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024)
         weight_size = calculate_tensor_size(weight)
 
         # If this weight is going to tip up over the maximal size, we split.
-        if current_block_size + weight_size > max_shard_size:
+        if current_block_size + weight_size > max_shard_size and current_block_size > 0:
             ret_block = current_block
             ret_block_size = current_block_size
             current_block = {}

@@ -140,9 +155,10 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
         isDTensor = False
         for state_tensor in state.values():
 
-            # When state_tensor is None (e.g., a SGD optimizer with momentum set to 0),
+            # When state_tensor is not of Tensor class,
+            # e.g., a SGD optimizer with momentum set to 0 can have None as state
             # The calculation of tensor size should be skipped to avoid error.
-            if state_tensor is None:
+            if not isinstance(state_tensor, torch.Tensor):
                 continue
 
             # If the states are stored as DTensors, mark isDTensor as true.

@@ -152,7 +168,7 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
 
         if not isDTensor:
 
-            if current_block_size + state_size > max_shard_size:
+            if current_block_size + state_size > max_shard_size and current_block_size > 0:
                 ret_block = current_block
                 ret_block_size = current_block_size
                 current_block = {}
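Editor's note: an illustrative sketch (not the library code) of why the broader isinstance() guard above matters. Adam-style optimizers keep 'step' as a plain Python number in per-parameter state, and SGD with momentum=0 keeps None; neither has a size that can be accumulated.

import torch

def state_nbytes(state: dict) -> int:
    """Sum tensor sizes in an optimizer state entry, skipping non-tensor values."""
    total = 0
    for value in state.values():
        if not isinstance(value, torch.Tensor):    # covers both None and an int/float 'step'
            continue
        total += value.numel() * value.element_size()
    return total

adam_like = {'step': 3, 'exp_avg': torch.zeros(8), 'exp_avg_sq': torch.zeros(8)}
sgd_like = {'momentum_buffer': None}
print(state_nbytes(adam_like), state_nbytes(sgd_like))    # 64 0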
@@ -119,3 +119,9 @@ class OptimizerWrapper:
         """
         raise NotImplementedError(
             "The method unscale_grad is only available for optimizers with mixed precision training")
+
+    def unwrap(self):
+        """
+        Unwrap the optimizer for checkpoint saving/loading.
+        """
+        return self.optim
@@ -5,6 +5,7 @@ import torch.distributed as dist
 from torch import Tensor
 from torch.distributed import ProcessGroup
 from torch.testing import assert_close
+from torch.utils._pytree import tree_flatten
 
 
 def assert_equal(a: Tensor, b: Tensor):

@@ -16,7 +17,12 @@ def assert_not_equal(a: Tensor, b: Tensor):
 
 
 def assert_close_loose(a: Tensor, b: Tensor, rtol: float = 1e-3, atol: float = 1e-3):
-    assert_close(a, b, rtol=rtol, atol=atol)
+    assert_close(a,
+                 b,
+                 rtol=rtol,
+                 atol=atol,
+                 msg=f"Tensor not close, shape: {a.shape} vs {b.shape}, \
+            dtype: {a.dtype} vs {b.dtype}")
 
 
 def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None):

@@ -33,25 +39,51 @@ def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None):
 
 
 def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True):
-    for k, v in d1.items():
-        if isinstance(v, dict):
-            check_state_dict_equal(v, d2[k])
-        elif isinstance(v, list):
-            for i in range(len(v)):
-                if isinstance(v[i], torch.Tensor):
-                    if not ignore_device:
-                        v[i] = v[i].to("cpu")
-                        d2[k][i] = d2[k][i].to("cpu")
-                    assert torch.equal(v[i], d2[k][i])
-                else:
-                    assert v[i] == d2[k][i]
-        elif isinstance(v, torch.Tensor):
-            if not ignore_device:
-                v = v.to("cpu")
-                d2[k] = d2[k].to("cpu")
-            assert torch.equal(v, d2[k])
-        else:
-            assert v == d2[k]
+    assert len(list(d1.keys())) == len(list(d2.keys())), \
+        f"Number of keys unequal: {len(list(d1.keys()))} vs {len(list(d2.keys()))}"
+    for k, v1 in d1.items():
+        assert k in d2
+        v2 = d2[k]
+        if isinstance(v1, dict):
+            assert isinstance(v2, dict)
+            check_state_dict_equal(v1, v2, ignore_device)
+        elif isinstance(v1, list):
+            assert isinstance(v2, list)
+            for v1_i, v2_i in zip(v1, v2):
+                if isinstance(v1_i, torch.Tensor):
+                    assert isinstance(v2_i, torch.Tensor)
+                    if not ignore_device:
+                        v1_i = v1_i.to("cpu")
+                        v2_i = v2_i.to("cpu")
+                    assert_close_loose(v1_i, v2_i)
+                elif isinstance(v1_i, dict):
+                    assert isinstance(v2_i, dict)
+                    check_state_dict_equal(v1_i, v2_i, ignore_device)
+                else:
+                    assert v1_i == v2_i, f"{v1_i} not equals to {v2_i}"
+        elif isinstance(v1, torch.Tensor):
+            assert isinstance(v2, torch.Tensor)
+            if not ignore_device:
+                v1 = v1.to("cpu")
+                v2 = v2.to("cpu")
+            assert_close_loose(v1, v2)
+        else:
+            assert v1 == v2, f"{v1} not equals to {v2}"
+
+
+def check_state_dict_equal_pytree(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True):
+    flat_d1, _ = tree_flatten(d1)
+    flat_d2, _ = tree_flatten(d2)
+    assert len(flat_d1) == len(flat_d2)
+    for v1, v2 in zip(flat_d1, flat_d2):
+        if isinstance(v1, torch.Tensor):
+            assert isinstance(v2, torch.Tensor)
+            if not ignore_device:
+                v1 = v1.to("cpu")
+                v2 = v2.to("cpu")
+            assert_close_loose(v1, v2)
+        else:
+            assert v1 == v2, f"{v1} not equals to {v2}"
 
 
 def assert_hf_output_close(out1: Any,
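Editor's note: a short usage sketch (not part of the commit) of the rewritten helper; the state dicts are toy values and the comparison runs on CPU, so ignore_device is left at its default.

import torch
from collections import OrderedDict
from colossalai.testing import check_state_dict_equal

sd_a = OrderedDict(weight=torch.ones(2, 2), step=3)
sd_b = OrderedDict(weight=torch.ones(2, 2), step=3)
check_state_dict_equal(sd_a, sd_b)      # passes: same keys, tensors within the loose tolerance

sd_b['weight'] += 1e-2                  # outside assert_close_loose's 1e-3 tolerance
# check_state_dict_equal(sd_a, sd_b)    # would now raise an AssertionError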
@@ -1,4 +1,6 @@
 # this code is inspired by the DeepSpeed library and implemented with our own design from scratch
+import copy
+import gc
 import math
 import warnings
 from typing import Any, Dict, Set, Tuple

@@ -101,6 +103,11 @@ class ZeroOptimizer(ColossalaiOptimizer):
         self.clipping_flag = clipping_norm > 0.0
         self.max_norm = clipping_norm
         self.verbose = verbose
+        self.param_groups_backup = list()
+
+        # Mapping from integer id to real/fake param tensor, used for checkpointing.
+        self.id_to_real_params: Dict[int, Parameter] = dict()
+        self.id_to_fake_params: Dict[int, Parameter] = dict()
 
         if self.clipping_flag:
             assert norm_type == 2.0, "ZeroOptimizer only supports L2 norm now"

@@ -301,25 +308,352 @@ class ZeroOptimizer(ColossalaiOptimizer):
             end = min(local_chunk.shard_size, param_info.end - local_chunk.shard_begin)
             return begin, end
 
+        param_id = -1
         for group in self.optim.param_groups:
             fake_params_list = list()
 
+            group_backup = {k: v for k, v in group.items() if k != 'params'}
+            group_ids = []
             for param in group['params']:
+
+                # Record the mapping of id to current param.
+                param_id += 1
+                self.id_to_real_params[param_id] = param
+                group_ids.append(param_id)
+
+                # If current param is controlled by current process, add it to fake_param.
                 if is_ddp_ignored(param):
                     continue
                 chunk16 = self.chunk_manager.get_chunk(param)
                 range_pair = get_range_pair(chunk16, param)
                 if range_pair[0] >= range_pair[1]:
                     continue
 
                 grad_device = self.module.grads_device[param]
                 fake_param = torch.nn.Parameter(torch.empty([0], device=grad_device))
                 self.param_to_chunk32[fake_param] = chunk16.paired_chunk
                 self.param_to_range[fake_param] = range_pair
 
+                self.id_to_fake_params[param_id] = fake_param
                 fake_params_list.append(fake_param)
 
+            # Update self.optim.param_groups as well as backup group.
             group['params'] = fake_params_list
+            group_backup['params'] = group_ids
+            self.param_groups_backup.append(group_backup)
+
+    def get_offsets(self, param_id: int) -> tuple:
+        '''
+        Args:
+            param_id(int): The id of parameter.
+
+        Returns:
+            chunk_offset(int): Offset of parameter inside the chunk.
+            shard_offset(int): Offset of its optimizer state shard
+                relative to the whole optimizer state.
+            shard_size(int): Length of parameter shard owned by current process.
+        '''
+
+        if param_id not in self.id_to_fake_params:
+            return -1, -1, -1
+        fake_param = self.id_to_fake_params[param_id]
+        chunk = self.param_to_chunk32[fake_param].paired_chunk
+        param = self.id_to_real_params[param_id]
+        param_info = chunk.tensors_info[param]
+
+        begin_in_chunk, end_in_chunk = self.param_to_range[fake_param]
+        chunk_offset = begin_in_chunk
+        shard_offset = begin_in_chunk + chunk.shard_begin - param_info.offset
+        shard_size = end_in_chunk - begin_in_chunk
+        assert chunk_offset >= 0 and shard_offset >= 0
+
+        return chunk_offset, shard_offset, shard_size
+
+    def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict:
+        """
+        Args:
+            param_id (int): id of the parameter whose state is to be gathered at master rank.
+            only_rank_0(bool): if True, states will be collected only on master rank, otherwise collected on every rank.
+
+        Returns:
+            collected_states(dict): the gathered optimzier state of parameter with given id
+                if this method is called by master rank, otherwise an empty dict.
+
+        This method can work only when called by all processes simultaneously.
+        """
+
+        # Get param & chunk & process group.
+        param = self.id_to_real_params[param_id]
+        fake_param = self.id_to_fake_params.get(param_id, None)
+        chunk = self.chunk_manager.get_chunk(param)
+        process_group = chunk.torch_pg
+        rank = dist.get_rank(process_group)
+        master_rank = 0
+        collected_states = {}
+
+        # Fetch names of states through all_gather.
+        local_state_names = None
+        if fake_param is not None:
+            local_state_names = list(self.optim.state[fake_param].keys())
+        gathered_state_names = [None for _ in range(dist.get_world_size(process_group))]
+        dist.barrier()
+        dist.all_gather_object(gathered_state_names, local_state_names)
+        state_names = None
+        for names in gathered_state_names:
+            if names is not None:
+                # Assume different devices share the same set of state names if they have.
+                state_names = copy.deepcopy(names)
+                break
+
+        # Directly return if this parameter doesn't have optimizer states.
+        # e.g. parameter freezed/layer dropped
+        if state_names is None:
+            return collected_states
+
+        # Boolean variable is_collector indicates that whether the current rank
+        # needs to gather the whole optimizer states.
+        # Only master rank is collector when only_rank_0 is True.
+        # Every rank is collector when only_rank_0 is False.
+        is_collector = (rank == master_rank) or (not only_rank_0)
+
+        # If the chunk is kept gathered,
+        # the parameteres are treated the same as that of those in strict DDP during training.
+        # So states can be directly fetched from current device.
+        if chunk.keep_gathered:
+            assert param_id in self.id_to_fake_params
+            if is_collector:
+                states = self.optim.state[fake_param]
+                for state_name in state_names:
+                    if state_name == 'step':
+                        # To keep aligned with pytorch, state 'step' is stored as a pytorch tensor with type float32.
+                        collected_states[state_name] = torch.tensor(states['step'],
+                                                                    dtype=torch.float32,
+                                                                    requires_grad=False).cpu()
+                    else:
+                        collected_states[state_name] = states[state_name].detach().clone().to(torch.float32).cpu()
+            return collected_states
+
+        # Check whether the param with given id is managed by current process.
+        own_param = param_id in self.id_to_fake_params
+
+        # Collector gets prepared for state collecting.
+        if is_collector:
+            for state_name in state_names:
+                if state_name == 'step':
+                    # To keep aligned with pytorch, state 'step' is stored as a pytorch tensor with type float32.
+                    collected_states[state_name] = torch.tensor(0.0, dtype=torch.float32, requires_grad=False).cpu()
+                else:
+                    collected_states[state_name] = torch.zeros(param.numel(), dtype=torch.float32,
+                                                               requires_grad=False).cpu()
+
+        # Materials for gathering, including compacted state tensors, and the offset of shard inside each state.
+        compacted_states = self.pack_optimizer_states_to_tensor(param_id, state_names) if own_param else None
+        _, shard_offset, shard_size = self.get_offsets(param_id)
+
+        # Collectors gather state shards through all_gathering.
+        gathered_state_shards = [None for _ in range(dist.get_world_size(process_group))]
+
+        dist.barrier()
+        dist.all_gather_object(gathered_state_shards, [compacted_states, shard_offset, shard_size])
+
+        if is_collector:
+            for state_shard in gathered_state_shards:
+                compacted_states = state_shard[0]
+                shard_offset = state_shard[1]
+                shard_size = state_shard[2]
+                if compacted_states is None:
+                    continue
+                self.load_from_compacted_states(compacted_states, collected_states, state_names, shard_offset,
+                                                shard_size)
+
+        # Clean gathered states
+        for state_shard in gathered_state_shards:
+            del state_shard[0]
+        gc.collect()
+
+        # Reshape tensors
+        if is_collector:
+            for state_name, state_tensor in collected_states.items():
+                if state_tensor.numel() == param.numel():
+                    collected_states[state_name] = torch.reshape(state_tensor, param.shape)
+
+        return collected_states
+
+    def pack_optimizer_states_to_tensor(self,
+                                        param_id: int,
+                                        state_names: list,
+                                        device: torch.device = torch.device('cuda'),
+                                        dtype: torch.dtype = torch.float32) -> torch.Tensor:
+        '''
+        With param id given, pack its optimizer states into a compact tensor and return.
+        '''
+        if param_id not in self.id_to_fake_params:
+            return None
+
+        fake_param = self.id_to_fake_params[param_id]
+        param_range = self.param_to_range[fake_param]
+        states = self.optim.state[fake_param]
+        shard_size = param_range[1] - param_range[0]
+        compacted_size = 0
+        for name in state_names:
+            if name == 'step':
+                compacted_size += 1
+            else:
+                compacted_size += shard_size
+        compacted_states = torch.zeros(compacted_size, dtype=dtype, device=device, requires_grad=False)
+
+        next_state_offset = 0
+        for state_name, state_tensor in states.items():
+            # State 'step' needs special operation.
+            if state_name == 'step':
+                if isinstance(state_tensor, torch.Tensor):
+                    compacted_states[next_state_offset] = state_tensor[0].item()
+                else:
+                    assert isinstance(state_tensor, int)
+                    compacted_states[next_state_offset] = state_tensor
+                next_state_offset += 1
+            else:
+                assert state_tensor.numel() == shard_size
+                compacted_states[next_state_offset:next_state_offset + shard_size].copy_(state_tensor)
+                next_state_offset += shard_size
+
+        return compacted_states
+
+    def load_from_compacted_states(self, compacted_states: torch.Tensor, collected_states: dict, state_names: list,
+                                   shard_start: int, shard_size: int):
+        '''
+        Given a tensor carrying compacted optimizer states,
+        update these states to collected_states.
+        '''
+        shard_end = shard_start + shard_size
+        next_state_offset = 0
+
+        for state_name in state_names:
+            if state_name == 'step':
+                collected_states['step'].data = torch.tensor(compacted_states[next_state_offset].item(),
+                                                             dtype=torch.float32,
+                                                             requires_grad=False).cpu()
+                next_state_offset += 1
+            else:
+                target_segment = collected_states[state_name][shard_start:shard_end]
+                target_segment.copy_(compacted_states[next_state_offset:next_state_offset + shard_size])
+                next_state_offset += shard_size
+
+    def state_dict(self, only_rank_0: bool = True) -> dict:
+        """
+        Args:
+            only_rank_0 (bool): a boolean value indicating whether the state_dict is collected
+                only on rank 0, dafault to True.
+
+        Returns:
+            The complete state of the optimizer as a :class:`dict`.
+            It contains two entries:
+
+            * state - a dict holding current optimization state. Its content
+                differs between optimizer classes.
+            * param_groups - a list containing all parameter groups where each
+                parameter group is a dict.
+
+        Warning: This method will gather and return the whole optimizer state_dict,
+            so it should be called only when memory resources are abundant.
+        """
+        state_dict = {}
+        state_dict['param_groups'] = copy.deepcopy(self.param_groups_backup)
+
+        torch_special_hyperparameters = {
+            'amsgrad': False,
+            'maximize': False,
+            'foreach': None,
+            'capturable': False,
+            'differentiable': False,
+            'fused': False
+        }
+
+        for group in state_dict['param_groups']:
+            for k, v in torch_special_hyperparameters.items():
+                if k not in group:
+                    group[k] = v
+
+        # Collect optimizer states.
+        state_dict['state'] = dict()
+        for param_id in self.id_to_real_params.keys():
+            dist.barrier()
+            state_dict['state'][param_id] = self.collect_states(param_id=param_id, only_rank_0=only_rank_0)
+        return state_dict
+
+    def load_param_groups(self, saved_param_groups: list):
+        """
+        Load saved_param_groups into
+        self.param_groups and self.param_groups_backup
+        """
+        self.param_groups_backup = copy.deepcopy(saved_param_groups)
+
+        # discard the older param_groups
+        self.optim.param_groups = []
+
+        for group in saved_param_groups:
+            fake_params_list = list()
+            updated_group = {k: v for k, v in group.items() if k != 'params'}
+            for param_id in group['params']:
+                if param_id not in self.id_to_fake_params:
+                    continue
+                fake_param = self.id_to_fake_params[param_id]
+                fake_params_list.append(fake_param)
+            updated_group['params'] = fake_params_list
+            self.optim.param_groups.append(updated_group)
+
+    def load_single_param_states(self, param_id: int, saved_states: dict):
+        """
+        Load saved optimizer states into parameter with given id.
+        """
+
+        def cast(param, state_range, value, key=None):
+            """
+            Make a copy of the needed segment of value and cast it to device of param.
+            """
+            assert isinstance(value, torch.Tensor)
+            ret_val = value
+            if (key == "step"):
+                assert value.numel() == 1
+                ret_val = int(value.item())
+            else:
+                state_start, state_end = state_range
+                ret_val = torch.zeros(state_end - state_start,
+                                      dtype=torch.float32,
+                                      device=param.device,
+                                      requires_grad=False)
+                ret_val.copy_(value.flatten()[state_start:state_end])
+            return ret_val
+
+        assert param_id in self.id_to_fake_params
+        fake_param = self.id_to_fake_params[param_id]
+        _, state_offset, param_size = self.get_offsets(param_id)
+        state_range = (state_offset, state_offset + param_size)
+
+        # Copy states assigned to param (and cast tensors to appropriate types).
+        updated_states = dict()
+        for k, v in saved_states.items():
+            updated_states[k] = cast(fake_param, state_range, v, k)
+            del v    # clean loaded states
+        self.optim.state[fake_param].update(updated_states)
+
+    def load_state_dict(self, state_dict: dict):
+        """Loads optimizer state from whole optimizer state_dict.
+           During loading, filter out the part of states not considered by current process.
+
+        Args:
+            state_dict (dict): optimizer state. Should be an object returned
+                from a call to :meth:`state_dict`.
+        """
+        assert 'param_groups' in state_dict
+        self.load_param_groups(state_dict['param_groups'])
+
+        state = state_dict['state']
+
+        for param_id, param_states in state.items():
+            if param_id in self.id_to_fake_params:
+                self.load_single_param_states(param_id, param_states)
+
+        # Epilogue for pytorch optimizer.
+        self.optim._hook_for_profile()    # To support multiprocessing pickle/unpickle.
+        self.optim.defaults.setdefault('differentiable', False)
 
 
 class GeminiAdamOptimizer(ZeroOptimizer):
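Editor's note: a stripped-down, self-contained sketch (not the library code) of the gather pattern that collect_states above relies on: each rank packs its optimizer-state shard plus its offset into a picklable object, all ranks exchange them with all_gather_object, and the collecting rank copies every shard into a full-size buffer. The single-process gloo setup exists only to make the sketch runnable on its own.

import torch
import torch.distributed as dist

def gather_full_state(local_shard: torch.Tensor, shard_offset: int, full_numel: int) -> torch.Tensor:
    """Collect unevenly sized shards from all ranks into one flat fp32 tensor."""
    world_size = dist.get_world_size()
    gathered = [None] * world_size
    # Objects are gathered (rather than tensors) because shard sizes differ per rank.
    dist.all_gather_object(gathered, (local_shard.cpu().float(), shard_offset))
    full = torch.zeros(full_numel, dtype=torch.float32)
    for shard, offset in gathered:
        full[offset:offset + shard.numel()].copy_(shard)
    return full

if __name__ == "__main__":
    dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29501",
                            rank=0, world_size=1)
    print(gather_full_state(torch.arange(4.0), 0, 4))    # tensor([0., 1., 2., 3.])
    dist.destroy_process_group()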
@@ -8,15 +8,18 @@ from utils import shared_tempdir
 import colossalai
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin
-from colossalai.booster.plugin.gemini_plugin import GeminiCheckpointIO
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.testing import check_state_dict_equal, parameterize, rerun_if_address_is_in_use, spawn
-from colossalai.zero import ZeroDDP
-from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
-from colossalai.zero.gemini.gemini_mgr import GeminiManager
+from colossalai.testing import (
+    check_state_dict_equal,
+    clear_cache_before_run,
+    parameterize,
+    rerun_if_address_is_in_use,
+    spawn,
+)
 from tests.kit.model_zoo import model_zoo
 
 
+@clear_cache_before_run()
 @parameterize('placement_policy', ['cuda', 'cpu'])
 @parameterize('model_name', ['transformers_bert_for_sequence_classification'])
 @parameterize('use_safetensors', [False, True])

@@ -29,33 +32,33 @@ def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: b
         pretrained_path = os.path.join(tempdir, 'pretrained')
         bert_model.config.save_pretrained(save_directory=pretrained_path)
 
-        # TODO(ver217): use boost api
-        config_dict, *_ = search_chunk_configuration(bert_model, search_range_m=1, search_interval=100)
-        chunk_manager = ChunkManager(config_dict)
-        gemini_manager = GeminiManager(placement_policy, chunk_manager)
-        bert_model = ZeroDDP(bert_model, gemini_manager)
-        bert_model.train()
-
-        ckpt_io = GeminiCheckpointIO()
+        plugin = GeminiPlugin(placement_policy=placement_policy)
+        booster = Booster(plugin=plugin)
+        bert_model, _, _, _, _ = booster.boost(bert_model)
         model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2
-        ckpt_io.save_model(bert_model, (pretrained_path),
+
+        booster.save_model(bert_model,
+                           pretrained_path,
                            True,
                            True,
                            '', (model_size / 3),
                            use_safetensors=use_safetensors)
         dist.barrier()
 
         new_bert_model = BertForSequenceClassification.from_pretrained(pretrained_path)
-        check_state_dict_equal(bert_model.state_dict(only_rank_0=False, dtype=torch.float32),
+        check_state_dict_equal(bert_model.unwrap().state_dict(only_rank_0=False, dtype=torch.float32),
                                new_bert_model.state_dict(), False)
 
 
+@clear_cache_before_run()
 @parameterize('placement_policy', ['cuda', 'cpu'])
-@parameterize('shard', [True, False])
+@parameterize('shard', [False])
 @parameterize('model_name', ['transformers_gpt'])
-def exam_state_dict(placement_policy, shard: bool, model_name: str):
+@parameterize('size_per_shard', [32])
+def exam_state_dict(placement_policy, shard: bool, model_name: str, size_per_shard: int):
     (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
     criterion = lambda x: x.mean()
-    plugin = GeminiPlugin(placement_policy=placement_policy)
+    plugin = GeminiPlugin(placement_policy=placement_policy, precision="fp16", initial_scale=(2**14))
     booster = Booster(plugin=plugin)
 
     model = model_fn()

@@ -78,18 +81,32 @@ def exam_state_dict(placement_policy, shard: bool, model_name: str):
     with shared_tempdir() as tempdir:
         model_ckpt_path = f"{tempdir}/model"
         optimizer_ckpt_path = f"{tempdir}/optimizer"
-        booster.save_model(model, model_ckpt_path)
-        if not shard:
-            # TODO(ver217): optimizer checkpointing is not supported for sharded checkpoint
-            booster.save_optimizer(optimizer, optimizer_ckpt_path)
+        booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
+
+        booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
         dist.barrier()
 
         booster.load_model(new_model, model_ckpt_path)
         check_state_dict_equal(model.unwrap().state_dict(only_rank_0=False),
                                new_model.unwrap().state_dict(only_rank_0=False), False)
-        if not shard:
-            booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
-            check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
 
+        booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
+        check_state_dict_equal(optimizer.unwrap().state_dict(only_rank_0=False),
+                               new_optimizer.unwrap().state_dict(only_rank_0=False), False)
+
+        # Check the new model/optimizer can successfully run.
+        data = data_gen_fn()
+        data = {
+            k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()
+        }
+        output = new_model(**data)
+        output = output_transform_fn(output)
+        output_key = list(output.keys())[0]
+        loss = criterion(output[output_key])
+        booster.backward(loss, new_optimizer)
+        new_optimizer.step()
+        booster.save_model(new_model, model_ckpt_path, shard=shard)
+        booster.save_optimizer(new_optimizer, optimizer_ckpt_path, shard=shard)
 
 
 def run_dist(rank, world_size, port):

@@ -100,7 +117,7 @@ def run_dist(rank, world_size, port):
 
 
 @pytest.mark.dist
-@pytest.mark.parametrize('world_size', [2])
+@pytest.mark.parametrize('world_size', [1, 2])
 @rerun_if_address_is_in_use()
 def test_gemini_ckpIO(world_size):
     spawn(run_dist, world_size)
@@ -0,0 +1,171 @@
+import pytest
+import torch
+import torch.distributed as dist
+from torch.optim import Adam
+from utils import shared_tempdir
+
+import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin, TorchDDPPlugin
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.testing import (
+    check_state_dict_equal,
+    clear_cache_before_run,
+    parameterize,
+    rerun_if_address_is_in_use,
+    spawn,
+)
+from tests.kit.model_zoo import model_zoo
+
+
+@clear_cache_before_run()
+@parameterize('shard', [False])
+@parameterize('model_name', ['transformers_gpt'])
+def exam_torch_load_from_gemini(shard: bool, model_name: str):
+
+    (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
+    criterion = lambda x: x.mean()
+    plugin = GeminiPlugin(precision="fp16", initial_scale=(2**14))
+    booster = Booster(plugin=plugin)
+
+    model = model_fn()
+    optimizer = HybridAdam(model.parameters(), lr=0.001)
+    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
+
+    data = data_gen_fn()
+    data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()}
+    output = model(**data)
+    output = output_transform_fn(output)
+    output_key = list(output.keys())[0]
+    loss = criterion(output[output_key])
+
+    booster.backward(loss, optimizer)
+    optimizer.step()
+
+    with shared_tempdir() as tempdir:
+        model_ckpt_path = f"{tempdir}/model"
+        optimizer_ckpt_path = f"{tempdir}/optimizer"
+
+        booster.save_model(model, model_ckpt_path, shard=shard)
+        booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard)
+        dist.barrier()
+
+        new_model = model_fn()
+        new_optimizer = Adam(new_model.parameters(), lr=0.001)
+        new_plugin = TorchDDPPlugin()
+        new_booster = Booster(plugin=new_plugin)
+        new_model, new_optimizer, criterion, _, _ = new_booster.boost(new_model, new_optimizer, criterion)
+
+        # Loading HybridAdam states to torch.Adam
+        new_booster.load_model(new_model, model_ckpt_path, strict=True)
+
+        # Add prefix to get aligned with pytorch parameter names.
+        check_state_dict_equal(
+            model.unwrap().state_dict(only_rank_0=False, prefix='module.module.', dtype=torch.float32),
+            new_model.state_dict(), False)
+
+        new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
+        check_state_dict_equal(optimizer.unwrap().state_dict(only_rank_0=False), new_optimizer.state_dict(), False)
+
+        # Check the new model/optimizer can successfully run.
+        data = data_gen_fn()
+        data = {
+            k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()
+        }
+        output = new_model(**data)
+        output = output_transform_fn(output)
+        output_key = list(output.keys())[0]
+        loss = criterion(output[output_key])
+        new_booster.backward(loss, new_optimizer)
+        new_optimizer.step()
+        new_booster.save_model(new_model, model_ckpt_path, shard=shard)
+        new_booster.save_optimizer(new_optimizer, optimizer_ckpt_path, shard=shard)
+
+
+@clear_cache_before_run()
+@parameterize('shard', [False])
+@parameterize('model_name', ['transformers_gpt'])
+def exam_gemini_load_from_torch(shard: bool, model_name: str):
+
+    (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
+    criterion = lambda x: x.mean()
+    plugin = TorchDDPPlugin()
+    booster = Booster(plugin=plugin)
+
+    model = model_fn()
+    optimizer = Adam(model.parameters(), lr=0.001)
+    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
+
+    data = data_gen_fn()
+    data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()}
+    output = model(**data)
+    output = output_transform_fn(output)
+    output_key = list(output.keys())[0]
+    loss = criterion(output[output_key])
+
+    booster.backward(loss, optimizer)
+    optimizer.step()
+
+    with shared_tempdir() as tempdir:
+        model_ckpt_path = f"{tempdir}/model"
+        optimizer_ckpt_path = f"{tempdir}/optimizer"
+
+        booster.save_model(model, model_ckpt_path, shard=shard)
+        booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard)
+        dist.barrier()
+
+        new_model = model_fn()
+        new_optimizer = HybridAdam(new_model.parameters(), lr=0.001)
+        new_plugin = GeminiPlugin()
+        new_booster = Booster(plugin=new_plugin)
+        new_model, new_optimizer, criterion, _, _ = new_booster.boost(new_model, new_optimizer, criterion)
+
+        # Loading torch.Adam states to HybridAdam
+        new_booster.load_model(new_model, model_ckpt_path, strict=True)
+
+        # Add prefix to get aligned with pytorch parameter names.
+        check_state_dict_equal(
+            new_model.unwrap().state_dict(only_rank_0=False, prefix='module.module.', dtype=torch.float32),
+            model.state_dict(), False)
+
+        new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
+        old_state_dict = optimizer.state_dict()
+        new_state_dict = new_optimizer.unwrap().state_dict(only_rank_0=False)
+
+        # Comparison of param_groups needs special care here,
+        # since not all hyperparameters in Adam are used by HybridAdam
+        hyperparameters_to_examine = ['params', 'lr', 'betas', 'eps', 'weight_decay']
+        for old_group, new_group in zip(old_state_dict['param_groups'], new_state_dict['param_groups']):
+            for k in hyperparameters_to_examine:
+                assert k in old_group and k in new_group, \
+                    f"Old group's keys: {list(old_group.keys())}, New group's keys: {list(new_group.keys())}"
+                assert old_group[k] == new_group[k]
+        check_state_dict_equal(old_state_dict['state'], new_state_dict['state'], False)
+
+        # Check the new model/optimizer can successfully run.
+        data = data_gen_fn()
+        data = {
+            k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()
+        }
+        output = new_model(**data)
+        output = output_transform_fn(output)
+        output_key = list(output.keys())[0]
+        loss = criterion(output[output_key])
+        new_booster.backward(loss, new_optimizer)
+        new_optimizer.step()
+        new_booster.save_model(new_model, model_ckpt_path, shard=shard)
+        new_booster.save_optimizer(new_optimizer, optimizer_ckpt_path, shard=shard)
+
+
+def run_dist(rank, world_size, port):
+    config = {}
+    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    exam_torch_load_from_gemini()
+    exam_gemini_load_from_torch()
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize('world_size', [1, 2])
+@rerun_if_address_is_in_use()
+def test_gemini_ckpIO(world_size):
+    spawn(run_dist, world_size)