Next commit [checkpointio] Unsharded Optimizer Checkpoint for Gemini Plugin (#4141)

* [checkpointio] unsharded optimizer checkpoint for Gemini plugin

* [checkpointio] unsharded optimizer checkpoint for Gemini using all_gather
pull/4217/head
Baizhou Zhang 2023-07-07 16:33:06 +08:00 committed by GitHub
parent fee32a3b78
commit 58913441a1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 684 additions and 83 deletions

View File

@ -33,44 +33,40 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
super().__init__()
self.coordinator = DistCoordinator()
def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True):
"""
Load model from checkpoint with automatic unwrapping.
"""
# the model should be unwrapped in self.load_model via ModelWrapper.unwrap
return super().load_unsharded_model(model, checkpoint, strict=strict)
def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
"""
Save model to checkpoint but only on master process.
Save sharded model to checkpoint but only on master process.
The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
As there is communication when getting state dict, this must be called on all processes.
"""
# the model should be unwrapped in self.load_model via ModelWrapper.unwrap
# as there is communication when get state dict, this must be called on all processes
state_dict = model.state_dict(only_rank_0=True)
if self.coordinator.is_master():
save_state_dict(state_dict, checkpoint, use_safetensors)
def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True):
"""
Load model from checkpoint with automatic unwrapping.
The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
"""
super().load_unsharded_model(model, checkpoint, strict=strict)
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
"""
Save optimizer to checkpoint but only on master process.
"""
# TODO(ver217): optimizer state dict is sharded
warnings.warn('GeminiPlugin does not support save full optimizer checkpoint now. Save it on every process.')
checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
warnings.warn(
'GeminiPlugin can only load optimizer checkpoint saved by itself with the same number of processes.')
checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
super().load_optimizer(optimizer, checkpoint)
def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
"""
Save model to checkpoint but only on master process.
Save unsharded optimizer state dict to checkpoint.
After calling optimizer.state_dict(), the complete optimizer states will be collected on master rank.
As there is communication when getting state dict, this must be called on all processes.
The saving process will only be executed by master rank.
"""
state_dict = optimizer.state_dict()
if self.coordinator.is_master():
super().save_lr_scheduler(lr_scheduler, checkpoint)
save_state_dict(state_dict, checkpoint, use_safetensors=False)
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str):
"""
Loading unsharded optimizer from checkpoint file.
For each process, only loading optimizer states of parameters it controls.
"""
super().load_unsharded_optimizer(optimizer, checkpoint)
def save_sharded_model(self,
model: GeminiDDP,
@ -82,6 +78,12 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
"""
Save sharded model
"""
if os.path.isfile(checkpoint_path):
logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
return
Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True, dtype=torch.float32)
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
total_size = 0
@ -117,6 +119,23 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
"""
return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)
def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str,
size_per_shard: int):
"""
Save sharded optimizer state dict to checkpoint folder.
As there is communication when getting state dict, this must be called on all processes.
"""
Path(checkpoint).mkdir(parents=True, exist_ok=True)
super().save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint_index_file: Path, prefix: str):
"""
Loading sharded optimizer from checkpoint folder, with index file given.
For each process, only loading optimizer states of parameters it controls.
"""
# TODO(Baizhou): To be implemented.
pass
class GeminiModel(ModelWrapper):
@ -193,7 +212,7 @@ class GeminiPlugin(DPPluginBase):
which will be used when using hybrid CPU optimizer.
This argument is meaningless when `placement_policy` of `GeminiManager` is not "auto".
Defaults to 0.0.
initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32.
initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**16.
min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1.
growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2.
backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5.
@ -219,7 +238,7 @@ class GeminiPlugin(DPPluginBase):
min_chunk_size_m: float = 32,
memstats: Optional[MemStats] = None,
gpu_margin_mem_ratio: float = 0.0,
initial_scale: float = 2**32,
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,

View File

@ -152,6 +152,7 @@ class CheckpointIO(ABC):
names to compose the keys in state_dict. Defaults to None.
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
"""
index_file_exists, index_file_path = has_index_file(checkpoint)
if Path(checkpoint).is_dir() and not index_file_exists:
@ -186,6 +187,7 @@ class CheckpointIO(ABC):
prefix (str): prefix for the optimizer checkpoint when shard = True. Default: None.
size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True.
"""
if shard:
self.save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
else:

View File

@ -28,6 +28,7 @@ from .utils import (
shard_model_checkpoint,
shard_optimizer_checkpoint,
sharded_optimizer_loading_epilogue,
unwrap_optimizer,
)
__all__ = ['GeneralCheckpointIO']
@ -59,7 +60,7 @@ class GeneralCheckpointIO(CheckpointIO):
# If optimizer is wrapped, unwrap it.
if isinstance(optimizer, OptimizerWrapper):
optimizer = optimizer.optim
optimizer = unwrap_optimizer(optimizer)
# Read checkpoint index file.
ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
@ -96,6 +97,11 @@ class GeneralCheckpointIO(CheckpointIO):
- A group file (pytorch_optim_group.bin) recording information of param_groups
- Multiple files (pytorch_optim-000XX.bin) that store state tensors of optimizer in a sharding way
"""
# If optimizer is wrapped, unwrap it.
if isinstance(optimizer, OptimizerWrapper):
optimizer = unwrap_optimizer(optimizer)
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
@ -121,9 +127,8 @@ class GeneralCheckpointIO(CheckpointIO):
shard, current_size = shard_pair
shard_file = get_shard_filename(states_name, idx)
total_size = total_size + current_size
for param_id in shard.keys():
index_file.append_weight_map(str(param_id), shard_file)
for key in shard.keys():
index_file.append_weight_map(key, shard_file)
checkpoint_file_path = os.path.join(checkpoint, shard_file)
save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
@ -177,7 +182,6 @@ class GeneralCheckpointIO(CheckpointIO):
total_size = total_size + shard_pair[1]
for key in shard.keys():
index_file.append_weight_map(key, shard_file)
checkpoint_file_path = os.path.join(checkpoint_path, shard_file)
save_state_dict(shard, checkpoint_file_path, use_safetensors)

View File

@ -10,6 +10,8 @@ import torch
import torch.nn as nn
from torch.optim import Optimizer
from colossalai.interface import OptimizerWrapper
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.tensor.d_tensor import is_distributed_tensor
SAFE_WEIGHTS_NAME = "model.safetensors"
@ -88,6 +90,19 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
# ======================================
# Helper functions for saving shard file
# ======================================
def unwrap_optimizer(optimizer: OptimizerWrapper):
'''
Unwrap a wrapped optimizer.
This method should be used before saving/loading it to/from sharded checkpoints.
'''
# TODO(Baizhou): ColossalaiOptimizer will be replaced with OptimizerWrapper in the future
unwrapped_optim = optimizer.optim
if isinstance(unwrapped_optim, ColossalaiOptimizer):
unwrapped_optim = unwrapped_optim.optim
return unwrapped_optim
def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
"""
Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
@ -103,7 +118,7 @@ def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024)
weight_size = calculate_tensor_size(weight)
# If this weight is going to tip up over the maximal size, we split.
if current_block_size + weight_size > max_shard_size:
if current_block_size + weight_size > max_shard_size and current_block_size > 0:
ret_block = current_block
ret_block_size = current_block_size
current_block = {}
@ -140,9 +155,10 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
isDTensor = False
for state_tensor in state.values():
# When state_tensor is None (e.g., a SGD optimizer with momentum set to 0),
# When state_tensor is not of Tensor class,
# e.g., a SGD optimizer with momentum set to 0 can have None as state
# The calculation of tensor size should be skipped to avoid error.
if state_tensor is None:
if not isinstance(state_tensor, torch.Tensor):
continue
# If the states are stored as DTensors, mark isDTensor as true.
@ -152,7 +168,7 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
if not isDTensor:
if current_block_size + state_size > max_shard_size:
if current_block_size + state_size > max_shard_size and current_block_size > 0:
ret_block = current_block
ret_block_size = current_block_size
current_block = {}

View File

@ -119,3 +119,9 @@ class OptimizerWrapper:
"""
raise NotImplementedError(
"The method unscale_grad is only available for optimizers with mixed precision training")
def unwrap(self):
"""
Unwrap the optimizer for checkpoint saving/loading.
"""
return self.optim

View File

@ -5,6 +5,7 @@ import torch.distributed as dist
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.testing import assert_close
from torch.utils._pytree import tree_flatten
def assert_equal(a: Tensor, b: Tensor):
@ -16,7 +17,12 @@ def assert_not_equal(a: Tensor, b: Tensor):
def assert_close_loose(a: Tensor, b: Tensor, rtol: float = 1e-3, atol: float = 1e-3):
assert_close(a, b, rtol=rtol, atol=atol)
assert_close(a,
b,
rtol=rtol,
atol=atol,
msg=f"Tensor not close, shape: {a.shape} vs {b.shape}, \
dtype: {a.dtype} vs {b.dtype}")
def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None):
@ -33,25 +39,51 @@ def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None):
def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True):
for k, v in d1.items():
if isinstance(v, dict):
check_state_dict_equal(v, d2[k])
elif isinstance(v, list):
for i in range(len(v)):
if isinstance(v[i], torch.Tensor):
assert len(list(d1.keys())) == len(list(d2.keys())), \
f"Number of keys unequal: {len(list(d1.keys()))} vs {len(list(d2.keys()))}"
for k, v1 in d1.items():
assert k in d2
v2 = d2[k]
if isinstance(v1, dict):
assert isinstance(v2, dict)
check_state_dict_equal(v1, v2, ignore_device)
elif isinstance(v1, list):
assert isinstance(v2, list)
for v1_i, v2_i in zip(v1, v2):
if isinstance(v1_i, torch.Tensor):
assert isinstance(v2_i, torch.Tensor)
if not ignore_device:
v[i] = v[i].to("cpu")
d2[k][i] = d2[k][i].to("cpu")
assert torch.equal(v[i], d2[k][i])
v1_i = v1_i.to("cpu")
v2_i = v2_i.to("cpu")
assert_close_loose(v1_i, v2_i)
elif isinstance(v1_i, dict):
assert isinstance(v2_i, dict)
check_state_dict_equal(v1_i, v2_i, ignore_device)
else:
assert v[i] == d2[k][i]
elif isinstance(v, torch.Tensor):
assert v1_i == v2_i, f"{v1_i} not equals to {v2_i}"
elif isinstance(v1, torch.Tensor):
assert isinstance(v2, torch.Tensor)
if not ignore_device:
v = v.to("cpu")
d2[k] = d2[k].to("cpu")
assert torch.equal(v, d2[k])
v1 = v1.to("cpu")
v2 = v2.to("cpu")
assert_close_loose(v1, v2)
else:
assert v == d2[k]
assert v1 == v2, f"{v1} not equals to {v2}"
def check_state_dict_equal_pytree(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True):
flat_d1, _ = tree_flatten(d1)
flat_d2, _ = tree_flatten(d2)
assert len(flat_d1) == len(flat_d2)
for v1, v2 in zip(flat_d1, flat_d2):
if isinstance(v1, torch.Tensor):
assert isinstance(v2, torch.Tensor)
if not ignore_device:
v1 = v1.to("cpu")
v2 = v2.to("cpu")
assert_close_loose(v1, v2)
else:
assert v1 == v2, f"{v1} not equals to {v2}"
def assert_hf_output_close(out1: Any,

View File

@ -1,4 +1,6 @@
# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
import copy
import gc
import math
import warnings
from typing import Any, Dict, Set, Tuple
@ -101,6 +103,11 @@ class ZeroOptimizer(ColossalaiOptimizer):
self.clipping_flag = clipping_norm > 0.0
self.max_norm = clipping_norm
self.verbose = verbose
self.param_groups_backup = list()
# Mapping from integer id to real/fake param tensor, used for checkpointing.
self.id_to_real_params: Dict[int, Parameter] = dict()
self.id_to_fake_params: Dict[int, Parameter] = dict()
if self.clipping_flag:
assert norm_type == 2.0, "ZeroOptimizer only supports L2 norm now"
@ -301,25 +308,352 @@ class ZeroOptimizer(ColossalaiOptimizer):
end = min(local_chunk.shard_size, param_info.end - local_chunk.shard_begin)
return begin, end
param_id = -1
for group in self.optim.param_groups:
fake_params_list = list()
group_backup = {k: v for k, v in group.items() if k != 'params'}
group_ids = []
for param in group['params']:
# Record the mapping of id to current param.
param_id += 1
self.id_to_real_params[param_id] = param
group_ids.append(param_id)
# If current param is controlled by current process, add it to fake_param.
if is_ddp_ignored(param):
continue
chunk16 = self.chunk_manager.get_chunk(param)
range_pair = get_range_pair(chunk16, param)
if range_pair[0] >= range_pair[1]:
continue
grad_device = self.module.grads_device[param]
fake_param = torch.nn.Parameter(torch.empty([0], device=grad_device))
self.param_to_chunk32[fake_param] = chunk16.paired_chunk
self.param_to_range[fake_param] = range_pair
self.id_to_fake_params[param_id] = fake_param
fake_params_list.append(fake_param)
# Update self.optim.param_groups as well as backup group.
group['params'] = fake_params_list
group_backup['params'] = group_ids
self.param_groups_backup.append(group_backup)
def get_offsets(self, param_id: int) -> tuple:
'''
Args:
param_id(int): The id of parameter.
Returns:
chunk_offset(int): Offset of parameter inside the chunk.
shard_offset(int): Offset of its optimizer state shard
relative to the whole optimizer state.
shard_size(int): Length of parameter shard owned by current process.
'''
if param_id not in self.id_to_fake_params:
return -1, -1, -1
fake_param = self.id_to_fake_params[param_id]
chunk = self.param_to_chunk32[fake_param].paired_chunk
param = self.id_to_real_params[param_id]
param_info = chunk.tensors_info[param]
begin_in_chunk, end_in_chunk = self.param_to_range[fake_param]
chunk_offset = begin_in_chunk
shard_offset = begin_in_chunk + chunk.shard_begin - param_info.offset
shard_size = end_in_chunk - begin_in_chunk
assert chunk_offset >= 0 and shard_offset >= 0
return chunk_offset, shard_offset, shard_size
def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict:
"""
Args:
param_id (int): id of the parameter whose state is to be gathered at master rank.
only_rank_0(bool): if True, states will be collected only on master rank, otherwise collected on every rank.
Returns:
collected_states(dict): the gathered optimzier state of parameter with given id
if this method is called by master rank, otherwise an empty dict.
This method can work only when called by all processes simultaneously.
"""
# Get param & chunk & process group.
param = self.id_to_real_params[param_id]
fake_param = self.id_to_fake_params.get(param_id, None)
chunk = self.chunk_manager.get_chunk(param)
process_group = chunk.torch_pg
rank = dist.get_rank(process_group)
master_rank = 0
collected_states = {}
# Fetch names of states through all_gather.
local_state_names = None
if fake_param is not None:
local_state_names = list(self.optim.state[fake_param].keys())
gathered_state_names = [None for _ in range(dist.get_world_size(process_group))]
dist.barrier()
dist.all_gather_object(gathered_state_names, local_state_names)
state_names = None
for names in gathered_state_names:
if names is not None:
# Assume different devices share the same set of state names if they have.
state_names = copy.deepcopy(names)
break
# Directly return if this parameter doesn't have optimizer states.
# e.g. parameter freezed/layer dropped
if state_names is None:
return collected_states
# Boolean variable is_collector indicates that whether the current rank
# needs to gather the whole optimizer states.
# Only master rank is collector when only_rank_0 is True.
# Every rank is collector when only_rank_0 is False.
is_collector = (rank == master_rank) or (not only_rank_0)
# If the chunk is kept gathered,
# the parameteres are treated the same as that of those in strict DDP during training.
# So states can be directly fetched from current device.
if chunk.keep_gathered:
assert param_id in self.id_to_fake_params
if is_collector:
states = self.optim.state[fake_param]
for state_name in state_names:
if state_name == 'step':
# To keep aligned with pytorch, state 'step' is stored as a pytorch tensor with type float32.
collected_states[state_name] = torch.tensor(states['step'],
dtype=torch.float32,
requires_grad=False).cpu()
else:
collected_states[state_name] = states[state_name].detach().clone().to(torch.float32).cpu()
return collected_states
# Check whether the param with given id is managed by current process.
own_param = param_id in self.id_to_fake_params
# Collector gets prepared for state collecting.
if is_collector:
for state_name in state_names:
if state_name == 'step':
# To keep aligned with pytorch, state 'step' is stored as a pytorch tensor with type float32.
collected_states[state_name] = torch.tensor(0.0, dtype=torch.float32, requires_grad=False).cpu()
else:
collected_states[state_name] = torch.zeros(param.numel(), dtype=torch.float32,
requires_grad=False).cpu()
# Materials for gathering, including compacted state tensors, and the offset of shard inside each state.
compacted_states = self.pack_optimizer_states_to_tensor(param_id, state_names) if own_param else None
_, shard_offset, shard_size = self.get_offsets(param_id)
# Collectors gather state shards through all_gathering.
gathered_state_shards = [None for _ in range(dist.get_world_size(process_group))]
dist.barrier()
dist.all_gather_object(gathered_state_shards, [compacted_states, shard_offset, shard_size])
if is_collector:
for state_shard in gathered_state_shards:
compacted_states = state_shard[0]
shard_offset = state_shard[1]
shard_size = state_shard[2]
if compacted_states is None:
continue
self.load_from_compacted_states(compacted_states, collected_states, state_names, shard_offset,
shard_size)
# Clean gathered states
for state_shard in gathered_state_shards:
del state_shard[0]
gc.collect()
# Reshape tensors
if is_collector:
for state_name, state_tensor in collected_states.items():
if state_tensor.numel() == param.numel():
collected_states[state_name] = torch.reshape(state_tensor, param.shape)
return collected_states
def pack_optimizer_states_to_tensor(self,
param_id: int,
state_names: list,
device: torch.device = torch.device('cuda'),
dtype: torch.dtype = torch.float32) -> torch.Tensor:
'''
With param id given, pack its optimizer states into a compact tensor and return.
'''
if param_id not in self.id_to_fake_params:
return None
fake_param = self.id_to_fake_params[param_id]
param_range = self.param_to_range[fake_param]
states = self.optim.state[fake_param]
shard_size = param_range[1] - param_range[0]
compacted_size = 0
for name in state_names:
if name == 'step':
compacted_size += 1
else:
compacted_size += shard_size
compacted_states = torch.zeros(compacted_size, dtype=dtype, device=device, requires_grad=False)
next_state_offset = 0
for state_name, state_tensor in states.items():
# State 'step' needs special operation.
if state_name == 'step':
if isinstance(state_tensor, torch.Tensor):
compacted_states[next_state_offset] = state_tensor[0].item()
else:
assert isinstance(state_tensor, int)
compacted_states[next_state_offset] = state_tensor
next_state_offset += 1
else:
assert state_tensor.numel() == shard_size
compacted_states[next_state_offset:next_state_offset + shard_size].copy_(state_tensor)
next_state_offset += shard_size
return compacted_states
def load_from_compacted_states(self, compacted_states: torch.Tensor, collected_states: dict, state_names: list,
shard_start: int, shard_size: int):
'''
Given a tensor carrying compacted optimizer states,
update these states to collected_states.
'''
shard_end = shard_start + shard_size
next_state_offset = 0
for state_name in state_names:
if state_name == 'step':
collected_states['step'].data = torch.tensor(compacted_states[next_state_offset].item(),
dtype=torch.float32,
requires_grad=False).cpu()
next_state_offset += 1
else:
target_segment = collected_states[state_name][shard_start:shard_end]
target_segment.copy_(compacted_states[next_state_offset:next_state_offset + shard_size])
next_state_offset += shard_size
def state_dict(self, only_rank_0: bool = True) -> dict:
"""
Args:
only_rank_0 (bool): a boolean value indicating whether the state_dict is collected
only on rank 0, dafault to True.
Returns:
The complete state of the optimizer as a :class:`dict`.
It contains two entries:
* state - a dict holding current optimization state. Its content
differs between optimizer classes.
* param_groups - a list containing all parameter groups where each
parameter group is a dict.
Warning: This method will gather and return the whole optimizer state_dict,
so it should be called only when memory resources are abundant.
"""
state_dict = {}
state_dict['param_groups'] = copy.deepcopy(self.param_groups_backup)
torch_special_hyperparameters = {
'amsgrad': False,
'maximize': False,
'foreach': None,
'capturable': False,
'differentiable': False,
'fused': False
}
for group in state_dict['param_groups']:
for k, v in torch_special_hyperparameters.items():
if k not in group:
group[k] = v
# Collect optimizer states.
state_dict['state'] = dict()
for param_id in self.id_to_real_params.keys():
dist.barrier()
state_dict['state'][param_id] = self.collect_states(param_id=param_id, only_rank_0=only_rank_0)
return state_dict
def load_param_groups(self, saved_param_groups: list):
"""
Load saved_param_groups into
self.param_groups and self.param_groups_backup
"""
self.param_groups_backup = copy.deepcopy(saved_param_groups)
# discard the older param_groups
self.optim.param_groups = []
for group in saved_param_groups:
fake_params_list = list()
updated_group = {k: v for k, v in group.items() if k != 'params'}
for param_id in group['params']:
if param_id not in self.id_to_fake_params:
continue
fake_param = self.id_to_fake_params[param_id]
fake_params_list.append(fake_param)
updated_group['params'] = fake_params_list
self.optim.param_groups.append(updated_group)
def load_single_param_states(self, param_id: int, saved_states: dict):
"""
Load saved optimizer states into parameter with given id.
"""
def cast(param, state_range, value, key=None):
"""
Make a copy of the needed segment of value and cast it to device of param.
"""
assert isinstance(value, torch.Tensor)
ret_val = value
if (key == "step"):
assert value.numel() == 1
ret_val = int(value.item())
else:
state_start, state_end = state_range
ret_val = torch.zeros(state_end - state_start,
dtype=torch.float32,
device=param.device,
requires_grad=False)
ret_val.copy_(value.flatten()[state_start:state_end])
return ret_val
assert param_id in self.id_to_fake_params
fake_param = self.id_to_fake_params[param_id]
_, state_offset, param_size = self.get_offsets(param_id)
state_range = (state_offset, state_offset + param_size)
# Copy states assigned to param (and cast tensors to appropriate types).
updated_states = dict()
for k, v in saved_states.items():
updated_states[k] = cast(fake_param, state_range, v, k)
del v # clean loaded states
self.optim.state[fake_param].update(updated_states)
def load_state_dict(self, state_dict: dict):
"""Loads optimizer state from whole optimizer state_dict.
During loading, filter out the part of states not considered by current process.
Args:
state_dict (dict): optimizer state. Should be an object returned
from a call to :meth:`state_dict`.
"""
assert 'param_groups' in state_dict
self.load_param_groups(state_dict['param_groups'])
state = state_dict['state']
for param_id, param_states in state.items():
if param_id in self.id_to_fake_params:
self.load_single_param_states(param_id, param_states)
# Epilogue for pytorch optimizer.
self.optim._hook_for_profile() # To support multiprocessing pickle/unpickle.
self.optim.defaults.setdefault('differentiable', False)
class GeminiAdamOptimizer(ZeroOptimizer):

View File

@ -8,15 +8,18 @@ from utils import shared_tempdir
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.booster.plugin.gemini_plugin import GeminiCheckpointIO
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import check_state_dict_equal, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.zero import ZeroDDP
from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.zero.gemini.gemini_mgr import GeminiManager
from colossalai.testing import (
check_state_dict_equal,
clear_cache_before_run,
parameterize,
rerun_if_address_is_in_use,
spawn,
)
from tests.kit.model_zoo import model_zoo
@clear_cache_before_run()
@parameterize('placement_policy', ['cuda', 'cpu'])
@parameterize('model_name', ['transformers_bert_for_sequence_classification'])
@parameterize('use_safetensors', [False, True])
@ -29,33 +32,33 @@ def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: b
pretrained_path = os.path.join(tempdir, 'pretrained')
bert_model.config.save_pretrained(save_directory=pretrained_path)
# TODO(ver217): use boost api
config_dict, *_ = search_chunk_configuration(bert_model, search_range_m=1, search_interval=100)
chunk_manager = ChunkManager(config_dict)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
bert_model = ZeroDDP(bert_model, gemini_manager)
bert_model.train()
ckpt_io = GeminiCheckpointIO()
plugin = GeminiPlugin(placement_policy=placement_policy)
booster = Booster(plugin=plugin)
bert_model, _, _, _, _ = booster.boost(bert_model)
model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2
ckpt_io.save_model(bert_model, (pretrained_path),
booster.save_model(bert_model,
pretrained_path,
True,
True,
'', (model_size / 3),
use_safetensors=use_safetensors)
dist.barrier()
new_bert_model = BertForSequenceClassification.from_pretrained(pretrained_path)
check_state_dict_equal(bert_model.state_dict(only_rank_0=False, dtype=torch.float32),
check_state_dict_equal(bert_model.unwrap().state_dict(only_rank_0=False, dtype=torch.float32),
new_bert_model.state_dict(), False)
@clear_cache_before_run()
@parameterize('placement_policy', ['cuda', 'cpu'])
@parameterize('shard', [True, False])
@parameterize('shard', [False])
@parameterize('model_name', ['transformers_gpt'])
def exam_state_dict(placement_policy, shard: bool, model_name: str):
@parameterize('size_per_shard', [32])
def exam_state_dict(placement_policy, shard: bool, model_name: str, size_per_shard: int):
(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
criterion = lambda x: x.mean()
plugin = GeminiPlugin(placement_policy=placement_policy)
plugin = GeminiPlugin(placement_policy=placement_policy, precision="fp16", initial_scale=(2**14))
booster = Booster(plugin=plugin)
model = model_fn()
@ -78,18 +81,32 @@ def exam_state_dict(placement_policy, shard: bool, model_name: str):
with shared_tempdir() as tempdir:
model_ckpt_path = f"{tempdir}/model"
optimizer_ckpt_path = f"{tempdir}/optimizer"
booster.save_model(model, model_ckpt_path)
if not shard:
# TODO(ver217): optimizer checkpointing is not supported for sharded checkpoint
booster.save_optimizer(optimizer, optimizer_ckpt_path)
booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
dist.barrier()
booster.load_model(new_model, model_ckpt_path)
check_state_dict_equal(model.unwrap().state_dict(only_rank_0=False),
new_model.unwrap().state_dict(only_rank_0=False), False)
if not shard:
booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
check_state_dict_equal(optimizer.unwrap().state_dict(only_rank_0=False),
new_optimizer.unwrap().state_dict(only_rank_0=False), False)
# Check the new model/optimizer can successfully run.
data = data_gen_fn()
data = {
k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()
}
output = new_model(**data)
output = output_transform_fn(output)
output_key = list(output.keys())[0]
loss = criterion(output[output_key])
booster.backward(loss, new_optimizer)
new_optimizer.step()
booster.save_model(new_model, model_ckpt_path, shard=shard)
booster.save_optimizer(new_optimizer, optimizer_ckpt_path, shard=shard)
def run_dist(rank, world_size, port):
@ -100,7 +117,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('world_size', [1, 2])
@rerun_if_address_is_in_use()
def test_gemini_ckpIO(world_size):
spawn(run_dist, world_size)

View File

@ -0,0 +1,171 @@
import pytest
import torch
import torch.distributed as dist
from torch.optim import Adam
from utils import shared_tempdir
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, TorchDDPPlugin
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import (
check_state_dict_equal,
clear_cache_before_run,
parameterize,
rerun_if_address_is_in_use,
spawn,
)
from tests.kit.model_zoo import model_zoo
@clear_cache_before_run()
@parameterize('shard', [False])
@parameterize('model_name', ['transformers_gpt'])
def exam_torch_load_from_gemini(shard: bool, model_name: str):
(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
criterion = lambda x: x.mean()
plugin = GeminiPlugin(precision="fp16", initial_scale=(2**14))
booster = Booster(plugin=plugin)
model = model_fn()
optimizer = HybridAdam(model.parameters(), lr=0.001)
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
data = data_gen_fn()
data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()}
output = model(**data)
output = output_transform_fn(output)
output_key = list(output.keys())[0]
loss = criterion(output[output_key])
booster.backward(loss, optimizer)
optimizer.step()
with shared_tempdir() as tempdir:
model_ckpt_path = f"{tempdir}/model"
optimizer_ckpt_path = f"{tempdir}/optimizer"
booster.save_model(model, model_ckpt_path, shard=shard)
booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard)
dist.barrier()
new_model = model_fn()
new_optimizer = Adam(new_model.parameters(), lr=0.001)
new_plugin = TorchDDPPlugin()
new_booster = Booster(plugin=new_plugin)
new_model, new_optimizer, criterion, _, _ = new_booster.boost(new_model, new_optimizer, criterion)
# Loading HybridAdam states to torch.Adam
new_booster.load_model(new_model, model_ckpt_path, strict=True)
# Add prefix to get aligned with pytorch parameter names.
check_state_dict_equal(
model.unwrap().state_dict(only_rank_0=False, prefix='module.module.', dtype=torch.float32),
new_model.state_dict(), False)
new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
check_state_dict_equal(optimizer.unwrap().state_dict(only_rank_0=False), new_optimizer.state_dict(), False)
# Check the new model/optimizer can successfully run.
data = data_gen_fn()
data = {
k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()
}
output = new_model(**data)
output = output_transform_fn(output)
output_key = list(output.keys())[0]
loss = criterion(output[output_key])
new_booster.backward(loss, new_optimizer)
new_optimizer.step()
new_booster.save_model(new_model, model_ckpt_path, shard=shard)
new_booster.save_optimizer(new_optimizer, optimizer_ckpt_path, shard=shard)
@clear_cache_before_run()
@parameterize('shard', [False])
@parameterize('model_name', ['transformers_gpt'])
def exam_gemini_load_from_torch(shard: bool, model_name: str):
(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
criterion = lambda x: x.mean()
plugin = TorchDDPPlugin()
booster = Booster(plugin=plugin)
model = model_fn()
optimizer = Adam(model.parameters(), lr=0.001)
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
data = data_gen_fn()
data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()}
output = model(**data)
output = output_transform_fn(output)
output_key = list(output.keys())[0]
loss = criterion(output[output_key])
booster.backward(loss, optimizer)
optimizer.step()
with shared_tempdir() as tempdir:
model_ckpt_path = f"{tempdir}/model"
optimizer_ckpt_path = f"{tempdir}/optimizer"
booster.save_model(model, model_ckpt_path, shard=shard)
booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard)
dist.barrier()
new_model = model_fn()
new_optimizer = HybridAdam(new_model.parameters(), lr=0.001)
new_plugin = GeminiPlugin()
new_booster = Booster(plugin=new_plugin)
new_model, new_optimizer, criterion, _, _ = new_booster.boost(new_model, new_optimizer, criterion)
# Loading torch.Adam states to HybridAdam
new_booster.load_model(new_model, model_ckpt_path, strict=True)
# Add prefix to get aligned with pytorch parameter names.
check_state_dict_equal(
new_model.unwrap().state_dict(only_rank_0=False, prefix='module.module.', dtype=torch.float32),
model.state_dict(), False)
new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
old_state_dict = optimizer.state_dict()
new_state_dict = new_optimizer.unwrap().state_dict(only_rank_0=False)
# Comparison of param_groups needs special care here,
# since not all hyperparameters in Adam are used by HybridAdam
hyperparameters_to_examine = ['params', 'lr', 'betas', 'eps', 'weight_decay']
for old_group, new_group in zip(old_state_dict['param_groups'], new_state_dict['param_groups']):
for k in hyperparameters_to_examine:
assert k in old_group and k in new_group, \
f"Old group's keys: {list(old_group.keys())}, New group's keys: {list(new_group.keys())}"
assert old_group[k] == new_group[k]
check_state_dict_equal(old_state_dict['state'], new_state_dict['state'], False)
# Check the new model/optimizer can successfully run.
data = data_gen_fn()
data = {
k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()
}
output = new_model(**data)
output = output_transform_fn(output)
output_key = list(output.keys())[0]
loss = criterion(output[output_key])
new_booster.backward(loss, new_optimizer)
new_optimizer.step()
new_booster.save_model(new_model, model_ckpt_path, shard=shard)
new_booster.save_optimizer(new_optimizer, optimizer_ckpt_path, shard=shard)
def run_dist(rank, world_size, port):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
exam_torch_load_from_gemini()
exam_gemini_load_from_torch()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@rerun_if_address_is_in_use()
def test_gemini_ckpIO(world_size):
spawn(run_dist, world_size)