From 58913441a1bd5df3848a4766e2f75a8ae0942121 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Fri, 7 Jul 2023 16:33:06 +0800 Subject: [PATCH] Next commit [checkpointio] Unsharded Optimizer Checkpoint for Gemini Plugin (#4141) * [checkpointio] unsharded optimizer checkpoint for Gemini plugin * [checkpointio] unsharded optimizer checkpoint for Gemini using all_gather --- colossalai/booster/plugin/gemini_plugin.py | 77 ++-- .../checkpoint_io/checkpoint_io_base.py | 2 + .../checkpoint_io/general_checkpoint_io.py | 14 +- colossalai/checkpoint_io/utils.py | 24 +- colossalai/interface/optimizer.py | 6 + colossalai/testing/comparison.py | 64 +++- colossalai/zero/gemini/gemini_optimizer.py | 340 +++++++++++++++++- .../test_gemini_checkpoint_io.py | 69 ++-- .../test_gemini_torch_compability.py | 171 +++++++++ 9 files changed, 684 insertions(+), 83 deletions(-) create mode 100644 tests/test_checkpoint_io/test_gemini_torch_compability.py diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 1173589fc..6191f271c 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -33,44 +33,40 @@ class GeminiCheckpointIO(GeneralCheckpointIO): super().__init__() self.coordinator = DistCoordinator() - def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True): - """ - Load model from checkpoint with automatic unwrapping. - """ - # the model should be unwrapped in self.load_model via ModelWrapper.unwrap - return super().load_unsharded_model(model, checkpoint, strict=strict) - def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): """ - Save model to checkpoint but only on master process. + Save sharded model to checkpoint but only on master process. + The model should be unwrapped in self.load_model via ModelWrapper.unwrap. + As there is communication when getting state dict, this must be called on all processes. """ - # the model should be unwrapped in self.load_model via ModelWrapper.unwrap - # as there is communication when get state dict, this must be called on all processes state_dict = model.state_dict(only_rank_0=True) if self.coordinator.is_master(): save_state_dict(state_dict, checkpoint, use_safetensors) + def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True): + """ + Load model from checkpoint with automatic unwrapping. + The model should be unwrapped in self.load_model via ModelWrapper.unwrap. + """ + super().load_unsharded_model(model, checkpoint, strict=strict) + def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool): """ - Save optimizer to checkpoint but only on master process. - """ - # TODO(ver217): optimizer state dict is sharded - warnings.warn('GeminiPlugin does not support save full optimizer checkpoint now. Save it on every process.') - checkpoint = f'{checkpoint}.rank{self.coordinator.rank}' - super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor) - - def load_optimizer(self, optimizer: Optimizer, checkpoint: str): - warnings.warn( - 'GeminiPlugin can only load optimizer checkpoint saved by itself with the same number of processes.') - checkpoint = f'{checkpoint}.rank{self.coordinator.rank}' - super().load_optimizer(optimizer, checkpoint) - - def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): - """ - Save model to checkpoint but only on master process. 
+ Save unsharded optimizer state dict to checkpoint. + After calling optimizer.state_dict(), the complete optimizer states will be collected on master rank. + As there is communication when getting state dict, this must be called on all processes. + The saving process will only be executed by master rank. """ + state_dict = optimizer.state_dict() if self.coordinator.is_master(): - super().save_lr_scheduler(lr_scheduler, checkpoint) + save_state_dict(state_dict, checkpoint, use_safetensors=False) + + def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str): + """ + Loading unsharded optimizer from checkpoint file. + For each process, only loading optimizer states of parameters it controls. + """ + super().load_unsharded_optimizer(optimizer, checkpoint) def save_sharded_model(self, model: GeminiDDP, @@ -82,6 +78,12 @@ class GeminiCheckpointIO(GeneralCheckpointIO): """ Save sharded model """ + if os.path.isfile(checkpoint_path): + logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file") + return + + Path(checkpoint_path).mkdir(parents=True, exist_ok=True) + state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True, dtype=torch.float32) weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors) total_size = 0 @@ -117,6 +119,23 @@ class GeminiCheckpointIO(GeneralCheckpointIO): """ return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False) + def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, + size_per_shard: int): + """ + Save sharded optimizer state dict to checkpoint folder. + As there is communication when getting state dict, this must be called on all processes. + """ + Path(checkpoint).mkdir(parents=True, exist_ok=True) + super().save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard) + + def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint_index_file: Path, prefix: str): + """ + Loading sharded optimizer from checkpoint folder, with index file given. + For each process, only loading optimizer states of parameters it controls. + """ + # TODO(Baizhou): To be implemented. + pass + class GeminiModel(ModelWrapper): @@ -193,7 +212,7 @@ class GeminiPlugin(DPPluginBase): which will be used when using hybrid CPU optimizer. This argument is meaningless when `placement_policy` of `GeminiManager` is not "auto". Defaults to 0.0. - initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32. + initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**16. min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1. growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2. backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5. 
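For reference, a minimal usage sketch of the new unsharded checkpointing path through the Booster API, mirroring the test code added later in this patch. The toy nn.Linear model, the checkpoint file names, and the two-process spawn are illustrative assumptions, not part of the patch itself.

    import torch
    import torch.nn as nn

    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import GeminiPlugin
    from colossalai.nn.optimizer import HybridAdam
    from colossalai.testing import spawn


    def run(rank, world_size, port):
        colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

        plugin = GeminiPlugin(placement_policy='cuda', precision='fp16', initial_scale=2**14)
        booster = Booster(plugin=plugin)

        model = nn.Linear(16, 4)    # toy model, for illustration only
        optimizer = HybridAdam(model.parameters(), lr=1e-3)
        model, optimizer, *_ = booster.boost(model, optimizer)

        # Run one step so the optimizer actually has states to checkpoint.
        loss = model(torch.rand(8, 16).cuda().half()).mean()
        booster.backward(loss, optimizer)
        optimizer.step()

        # Both calls must run on every rank: gathering the full state dict involves
        # communication, but only the master rank actually writes the files.
        booster.save_model(model, './model.pt', shard=False)
        booster.save_optimizer(optimizer, './optimizer.pt', shard=False)

        # Later, after building and boosting a fresh model/optimizer the same way:
        # booster.load_model(model, './model.pt')
        # booster.load_optimizer(optimizer, './optimizer.pt')


    if __name__ == '__main__':
        spawn(run, 2)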
@@ -219,7 +238,7 @@ class GeminiPlugin(DPPluginBase): min_chunk_size_m: float = 32, memstats: Optional[MemStats] = None, gpu_margin_mem_ratio: float = 0.0, - initial_scale: float = 2**32, + initial_scale: float = 2**16, min_scale: float = 1, growth_factor: float = 2, backoff_factor: float = 0.5, diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py index 8ff9d87c2..baff24e1c 100644 --- a/colossalai/checkpoint_io/checkpoint_io_base.py +++ b/colossalai/checkpoint_io/checkpoint_io_base.py @@ -152,6 +152,7 @@ class CheckpointIO(ABC): names to compose the keys in state_dict. Defaults to None. size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024. """ + index_file_exists, index_file_path = has_index_file(checkpoint) if Path(checkpoint).is_dir() and not index_file_exists: @@ -186,6 +187,7 @@ class CheckpointIO(ABC): prefix (str): prefix for the optimizer checkpoint when shard = True. Default: None. size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True. """ + if shard: self.save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard) else: diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index 26cafcada..e1d906694 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -28,6 +28,7 @@ from .utils import ( shard_model_checkpoint, shard_optimizer_checkpoint, sharded_optimizer_loading_epilogue, + unwrap_optimizer, ) __all__ = ['GeneralCheckpointIO'] @@ -59,7 +60,7 @@ class GeneralCheckpointIO(CheckpointIO): # If optimizer is wrapped, unwrap it. if isinstance(optimizer, OptimizerWrapper): - optimizer = optimizer.optim + optimizer = unwrap_optimizer(optimizer) # Read checkpoint index file. ckpt_index_file = CheckpointIndexFile.from_file(index_file_path) @@ -96,6 +97,11 @@ class GeneralCheckpointIO(CheckpointIO): - A group file (pytorch_optim_group.bin) recording information of param_groups - Multiple files (pytorch_optim-000XX.bin) that store state tensors of optimizer in a sharding way """ + + # If optimizer is wrapped, unwrap it. 
+ if isinstance(optimizer, OptimizerWrapper): + optimizer = unwrap_optimizer(optimizer) + if os.path.isfile(checkpoint): logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") return @@ -121,9 +127,8 @@ class GeneralCheckpointIO(CheckpointIO): shard, current_size = shard_pair shard_file = get_shard_filename(states_name, idx) total_size = total_size + current_size - for param_id in shard.keys(): - index_file.append_weight_map(str(param_id), shard_file) - + for key in shard.keys(): + index_file.append_weight_map(key, shard_file) checkpoint_file_path = os.path.join(checkpoint, shard_file) save_state_dict(shard, checkpoint_file_path, use_safetensors=False) @@ -177,7 +182,6 @@ class GeneralCheckpointIO(CheckpointIO): total_size = total_size + shard_pair[1] for key in shard.keys(): index_file.append_weight_map(key, shard_file) - checkpoint_file_path = os.path.join(checkpoint_path, shard_file) save_state_dict(shard, checkpoint_file_path, use_safetensors) diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 485577b96..19e28c3f7 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -10,6 +10,8 @@ import torch import torch.nn as nn from torch.optim import Optimizer +from colossalai.interface import OptimizerWrapper +from colossalai.nn.optimizer import ColossalaiOptimizer from colossalai.tensor.d_tensor import is_distributed_tensor SAFE_WEIGHTS_NAME = "model.safetensors" @@ -88,6 +90,19 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool: # ====================================== # Helper functions for saving shard file # ====================================== +def unwrap_optimizer(optimizer: OptimizerWrapper): + ''' + Unwrap a wrapped optimizer. + This method should be used before saving/loading it to/from sharded checkpoints. + ''' + + # TODO(Baizhou): ColossalaiOptimizer will be replaced with OptimizerWrapper in the future + unwrapped_optim = optimizer.optim + if isinstance(unwrapped_optim, ColossalaiOptimizer): + unwrapped_optim = unwrapped_optim.optim + return unwrapped_optim + + def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]: """ Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a @@ -103,7 +118,7 @@ def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) weight_size = calculate_tensor_size(weight) # If this weight is going to tip up over the maximal size, we split. - if current_block_size + weight_size > max_shard_size: + if current_block_size + weight_size > max_shard_size and current_block_size > 0: ret_block = current_block ret_block_size = current_block_size current_block = {} @@ -140,9 +155,10 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> isDTensor = False for state_tensor in state.values(): - # When state_tensor is None (e.g., a SGD optimizer with momentum set to 0), + # When state_tensor is not of Tensor class, + # e.g., a SGD optimizer with momentum set to 0 can have None as state # The calculation of tensor size should be skipped to avoid error. - if state_tensor is None: + if not isinstance(state_tensor, torch.Tensor): continue # If the states are stored as DTensors, mark isDTensor as true. 
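The `current_block_size > 0` guard added to both sharding helpers is what keeps a tensor larger than max_shard_size from producing an empty leading shard. Below is a self-contained sketch of that size-based splitting loop; it is simplified (no DTensor handling), and `shard_state_dict` / `calculate_tensor_size` here are local stand-ins rather than the library helpers.

    from collections import OrderedDict
    from typing import Iterator, Tuple

    import torch


    def calculate_tensor_size(tensor: torch.Tensor) -> float:
        # Tensor size in MB, mirroring what the real utility computes.
        return tensor.numel() * tensor.element_size() / 1024**2


    def shard_state_dict(state_dict: dict, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
        current_block, current_block_size = OrderedDict(), 0
        for key, weight in state_dict.items():
            weight_size = calculate_tensor_size(weight)
            # Only split when the current block is non-empty; otherwise a single
            # oversized tensor would be preceded by an empty shard.
            if current_block_size + weight_size > max_shard_size and current_block_size > 0:
                yield current_block, current_block_size
                current_block, current_block_size = OrderedDict(), 0
            current_block[key] = weight
            current_block_size += weight_size
        yield current_block, current_block_size


    if __name__ == '__main__':
        sd = {'big': torch.zeros(1024, 1024), 'small': torch.zeros(8)}    # ~4 MB and ~32 B
        shards = list(shard_state_dict(sd, max_shard_size=1))
        print(len(shards))    # 2 shards, neither of them empty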
@@ -152,7 +168,7 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> if not isDTensor: - if current_block_size + state_size > max_shard_size: + if current_block_size + state_size > max_shard_size and current_block_size > 0: ret_block = current_block ret_block_size = current_block_size current_block = {} diff --git a/colossalai/interface/optimizer.py b/colossalai/interface/optimizer.py index dd9acab17..0eaf2e1ef 100644 --- a/colossalai/interface/optimizer.py +++ b/colossalai/interface/optimizer.py @@ -119,3 +119,9 @@ class OptimizerWrapper: """ raise NotImplementedError( "The method unscale_grad is only available for optimizers with mixed precision training") + + def unwrap(self): + """ + Unwrap the optimizer for checkpoint saving/loading. + """ + return self.optim diff --git a/colossalai/testing/comparison.py b/colossalai/testing/comparison.py index 5cbfb936b..8d9ec8ab5 100644 --- a/colossalai/testing/comparison.py +++ b/colossalai/testing/comparison.py @@ -5,6 +5,7 @@ import torch.distributed as dist from torch import Tensor from torch.distributed import ProcessGroup from torch.testing import assert_close +from torch.utils._pytree import tree_flatten def assert_equal(a: Tensor, b: Tensor): @@ -16,7 +17,12 @@ def assert_not_equal(a: Tensor, b: Tensor): def assert_close_loose(a: Tensor, b: Tensor, rtol: float = 1e-3, atol: float = 1e-3): - assert_close(a, b, rtol=rtol, atol=atol) + assert_close(a, + b, + rtol=rtol, + atol=atol, + msg=f"Tensor not close, shape: {a.shape} vs {b.shape}, \ + dtype: {a.dtype} vs {b.dtype}") def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None): @@ -33,25 +39,51 @@ def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None): def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True): - for k, v in d1.items(): - if isinstance(v, dict): - check_state_dict_equal(v, d2[k]) - elif isinstance(v, list): - for i in range(len(v)): - if isinstance(v[i], torch.Tensor): + assert len(list(d1.keys())) == len(list(d2.keys())), \ + f"Number of keys unequal: {len(list(d1.keys()))} vs {len(list(d2.keys()))}" + for k, v1 in d1.items(): + assert k in d2 + v2 = d2[k] + if isinstance(v1, dict): + assert isinstance(v2, dict) + check_state_dict_equal(v1, v2, ignore_device) + elif isinstance(v1, list): + assert isinstance(v2, list) + for v1_i, v2_i in zip(v1, v2): + if isinstance(v1_i, torch.Tensor): + assert isinstance(v2_i, torch.Tensor) if not ignore_device: - v[i] = v[i].to("cpu") - d2[k][i] = d2[k][i].to("cpu") - assert torch.equal(v[i], d2[k][i]) + v1_i = v1_i.to("cpu") + v2_i = v2_i.to("cpu") + assert_close_loose(v1_i, v2_i) + elif isinstance(v1_i, dict): + assert isinstance(v2_i, dict) + check_state_dict_equal(v1_i, v2_i, ignore_device) else: - assert v[i] == d2[k][i] - elif isinstance(v, torch.Tensor): + assert v1_i == v2_i, f"{v1_i} not equals to {v2_i}" + elif isinstance(v1, torch.Tensor): + assert isinstance(v2, torch.Tensor) if not ignore_device: - v = v.to("cpu") - d2[k] = d2[k].to("cpu") - assert torch.equal(v, d2[k]) + v1 = v1.to("cpu") + v2 = v2.to("cpu") + assert_close_loose(v1, v2) else: - assert v == d2[k] + assert v1 == v2, f"{v1} not equals to {v2}" + + +def check_state_dict_equal_pytree(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True): + flat_d1, _ = tree_flatten(d1) + flat_d2, _ = tree_flatten(d2) + assert len(flat_d1) == len(flat_d2) + for v1, v2 in zip(flat_d1, flat_d2): + if isinstance(v1, torch.Tensor): + assert isinstance(v2, torch.Tensor) + if 
not ignore_device: + v1 = v1.to("cpu") + v2 = v2.to("cpu") + assert_close_loose(v1, v2) + else: + assert v1 == v2, f"{v1} not equals to {v2}" def assert_hf_output_close(out1: Any, diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index 267deb1e8..99aff6f1c 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -1,4 +1,6 @@ # this code is inspired by the DeepSpeed library and implemented with our own design from scratch +import copy +import gc import math import warnings from typing import Any, Dict, Set, Tuple @@ -101,6 +103,11 @@ class ZeroOptimizer(ColossalaiOptimizer): self.clipping_flag = clipping_norm > 0.0 self.max_norm = clipping_norm self.verbose = verbose + self.param_groups_backup = list() + + # Mapping from integer id to real/fake param tensor, used for checkpointing. + self.id_to_real_params: Dict[int, Parameter] = dict() + self.id_to_fake_params: Dict[int, Parameter] = dict() if self.clipping_flag: assert norm_type == 2.0, "ZeroOptimizer only supports L2 norm now" @@ -301,25 +308,352 @@ class ZeroOptimizer(ColossalaiOptimizer): end = min(local_chunk.shard_size, param_info.end - local_chunk.shard_begin) return begin, end + param_id = -1 for group in self.optim.param_groups: fake_params_list = list() - + group_backup = {k: v for k, v in group.items() if k != 'params'} + group_ids = [] for param in group['params']: + + # Record the mapping of id to current param. + param_id += 1 + self.id_to_real_params[param_id] = param + group_ids.append(param_id) + + # If current param is controlled by current process, add it to fake_param. if is_ddp_ignored(param): continue chunk16 = self.chunk_manager.get_chunk(param) range_pair = get_range_pair(chunk16, param) if range_pair[0] >= range_pair[1]: continue - grad_device = self.module.grads_device[param] fake_param = torch.nn.Parameter(torch.empty([0], device=grad_device)) self.param_to_chunk32[fake_param] = chunk16.paired_chunk self.param_to_range[fake_param] = range_pair - + self.id_to_fake_params[param_id] = fake_param fake_params_list.append(fake_param) + # Update self.optim.param_groups as well as backup group. group['params'] = fake_params_list + group_backup['params'] = group_ids + self.param_groups_backup.append(group_backup) + + def get_offsets(self, param_id: int) -> tuple: + ''' + Args: + param_id(int): The id of parameter. + + Returns: + chunk_offset(int): Offset of parameter inside the chunk. + shard_offset(int): Offset of its optimizer state shard + relative to the whole optimizer state. + shard_size(int): Length of parameter shard owned by current process. + ''' + + if param_id not in self.id_to_fake_params: + return -1, -1, -1 + fake_param = self.id_to_fake_params[param_id] + chunk = self.param_to_chunk32[fake_param].paired_chunk + param = self.id_to_real_params[param_id] + param_info = chunk.tensors_info[param] + + begin_in_chunk, end_in_chunk = self.param_to_range[fake_param] + chunk_offset = begin_in_chunk + shard_offset = begin_in_chunk + chunk.shard_begin - param_info.offset + shard_size = end_in_chunk - begin_in_chunk + assert chunk_offset >= 0 and shard_offset >= 0 + + return chunk_offset, shard_offset, shard_size + + def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict: + """ + Args: + param_id (int): id of the parameter whose state is to be gathered at master rank. + only_rank_0(bool): if True, states will be collected only on master rank, otherwise collected on every rank. 
+
+        Returns:
+            collected_states(dict): the gathered optimizer state of the parameter with the given id
+            if this method is called by the master rank, otherwise an empty dict.
+
+        This method works only when called by all processes simultaneously.
+        """
+
+        # Get param & chunk & process group.
+        param = self.id_to_real_params[param_id]
+        fake_param = self.id_to_fake_params.get(param_id, None)
+        chunk = self.chunk_manager.get_chunk(param)
+        process_group = chunk.torch_pg
+        rank = dist.get_rank(process_group)
+        master_rank = 0
+        collected_states = {}
+
+        # Fetch names of states through all_gather.
+        local_state_names = None
+        if fake_param is not None:
+            local_state_names = list(self.optim.state[fake_param].keys())
+        gathered_state_names = [None for _ in range(dist.get_world_size(process_group))]
+        dist.barrier()
+        dist.all_gather_object(gathered_state_names, local_state_names)
+        state_names = None
+        for names in gathered_state_names:
+            if names is not None:
+                # Assume different devices share the same set of state names if they have any.
+                state_names = copy.deepcopy(names)
+                break
+
+        # Directly return if this parameter doesn't have optimizer states,
+        # e.g. a frozen parameter or a dropped layer.
+        if state_names is None:
+            return collected_states
+
+        # The boolean flag is_collector indicates whether the current rank
+        # needs to gather the whole optimizer state.
+        # Only the master rank is a collector when only_rank_0 is True.
+        # Every rank is a collector when only_rank_0 is False.
+        is_collector = (rank == master_rank) or (not only_rank_0)
+
+        # If the chunk is kept gathered,
+        # the parameters are treated the same as those in strict DDP during training,
+        # so states can be fetched directly from the current device.
+        if chunk.keep_gathered:
+            assert param_id in self.id_to_fake_params
+            if is_collector:
+                states = self.optim.state[fake_param]
+                for state_name in state_names:
+                    if state_name == 'step':
+                        # To stay aligned with PyTorch, state 'step' is stored as a float32 tensor.
+                        collected_states[state_name] = torch.tensor(states['step'],
+                                                                    dtype=torch.float32,
+                                                                    requires_grad=False).cpu()
+                    else:
+                        collected_states[state_name] = states[state_name].detach().clone().to(torch.float32).cpu()
+            return collected_states
+
+        # Check whether the param with the given id is managed by the current process.
+        own_param = param_id in self.id_to_fake_params
+
+        # The collector gets prepared for state collecting.
+        if is_collector:
+            for state_name in state_names:
+                if state_name == 'step':
+                    # To stay aligned with PyTorch, state 'step' is stored as a float32 tensor.
+                    collected_states[state_name] = torch.tensor(0.0, dtype=torch.float32, requires_grad=False).cpu()
+                else:
+                    collected_states[state_name] = torch.zeros(param.numel(), dtype=torch.float32,
+                                                               requires_grad=False).cpu()
+
+        # Materials for gathering: the compacted state tensor and the offset/length of this rank's shard inside each state.
+        compacted_states = self.pack_optimizer_states_to_tensor(param_id, state_names) if own_param else None
+        _, shard_offset, shard_size = self.get_offsets(param_id)
+
+        # Collectors gather state shards through all_gather.
+ gathered_state_shards = [None for _ in range(dist.get_world_size(process_group))] + + dist.barrier() + dist.all_gather_object(gathered_state_shards, [compacted_states, shard_offset, shard_size]) + + if is_collector: + for state_shard in gathered_state_shards: + compacted_states = state_shard[0] + shard_offset = state_shard[1] + shard_size = state_shard[2] + if compacted_states is None: + continue + self.load_from_compacted_states(compacted_states, collected_states, state_names, shard_offset, + shard_size) + + # Clean gathered states + for state_shard in gathered_state_shards: + del state_shard[0] + gc.collect() + + # Reshape tensors + if is_collector: + for state_name, state_tensor in collected_states.items(): + if state_tensor.numel() == param.numel(): + collected_states[state_name] = torch.reshape(state_tensor, param.shape) + + return collected_states + + def pack_optimizer_states_to_tensor(self, + param_id: int, + state_names: list, + device: torch.device = torch.device('cuda'), + dtype: torch.dtype = torch.float32) -> torch.Tensor: + ''' + With param id given, pack its optimizer states into a compact tensor and return. + ''' + if param_id not in self.id_to_fake_params: + return None + + fake_param = self.id_to_fake_params[param_id] + param_range = self.param_to_range[fake_param] + states = self.optim.state[fake_param] + shard_size = param_range[1] - param_range[0] + compacted_size = 0 + for name in state_names: + if name == 'step': + compacted_size += 1 + else: + compacted_size += shard_size + compacted_states = torch.zeros(compacted_size, dtype=dtype, device=device, requires_grad=False) + + next_state_offset = 0 + for state_name, state_tensor in states.items(): + # State 'step' needs special operation. + if state_name == 'step': + if isinstance(state_tensor, torch.Tensor): + compacted_states[next_state_offset] = state_tensor[0].item() + else: + assert isinstance(state_tensor, int) + compacted_states[next_state_offset] = state_tensor + next_state_offset += 1 + else: + assert state_tensor.numel() == shard_size + compacted_states[next_state_offset:next_state_offset + shard_size].copy_(state_tensor) + next_state_offset += shard_size + + return compacted_states + + def load_from_compacted_states(self, compacted_states: torch.Tensor, collected_states: dict, state_names: list, + shard_start: int, shard_size: int): + ''' + Given a tensor carrying compacted optimizer states, + update these states to collected_states. + ''' + shard_end = shard_start + shard_size + next_state_offset = 0 + + for state_name in state_names: + if state_name == 'step': + collected_states['step'].data = torch.tensor(compacted_states[next_state_offset].item(), + dtype=torch.float32, + requires_grad=False).cpu() + next_state_offset += 1 + else: + target_segment = collected_states[state_name][shard_start:shard_end] + target_segment.copy_(compacted_states[next_state_offset:next_state_offset + shard_size]) + next_state_offset += shard_size + + def state_dict(self, only_rank_0: bool = True) -> dict: + """ + Args: + only_rank_0 (bool): a boolean value indicating whether the state_dict is collected + only on rank 0, dafault to True. + + Returns: + The complete state of the optimizer as a :class:`dict`. + It contains two entries: + + * state - a dict holding current optimization state. Its content + differs between optimizer classes. + * param_groups - a list containing all parameter groups where each + parameter group is a dict. 
+ + Warning: This method will gather and return the whole optimizer state_dict, + so it should be called only when memory resources are abundant. + """ + state_dict = {} + state_dict['param_groups'] = copy.deepcopy(self.param_groups_backup) + + torch_special_hyperparameters = { + 'amsgrad': False, + 'maximize': False, + 'foreach': None, + 'capturable': False, + 'differentiable': False, + 'fused': False + } + + for group in state_dict['param_groups']: + for k, v in torch_special_hyperparameters.items(): + if k not in group: + group[k] = v + + # Collect optimizer states. + state_dict['state'] = dict() + for param_id in self.id_to_real_params.keys(): + dist.barrier() + state_dict['state'][param_id] = self.collect_states(param_id=param_id, only_rank_0=only_rank_0) + return state_dict + + def load_param_groups(self, saved_param_groups: list): + """ + Load saved_param_groups into + self.param_groups and self.param_groups_backup + """ + self.param_groups_backup = copy.deepcopy(saved_param_groups) + + # discard the older param_groups + self.optim.param_groups = [] + + for group in saved_param_groups: + fake_params_list = list() + updated_group = {k: v for k, v in group.items() if k != 'params'} + for param_id in group['params']: + if param_id not in self.id_to_fake_params: + continue + fake_param = self.id_to_fake_params[param_id] + fake_params_list.append(fake_param) + updated_group['params'] = fake_params_list + self.optim.param_groups.append(updated_group) + + def load_single_param_states(self, param_id: int, saved_states: dict): + """ + Load saved optimizer states into parameter with given id. + """ + + def cast(param, state_range, value, key=None): + """ + Make a copy of the needed segment of value and cast it to device of param. + """ + assert isinstance(value, torch.Tensor) + ret_val = value + if (key == "step"): + assert value.numel() == 1 + ret_val = int(value.item()) + else: + state_start, state_end = state_range + ret_val = torch.zeros(state_end - state_start, + dtype=torch.float32, + device=param.device, + requires_grad=False) + ret_val.copy_(value.flatten()[state_start:state_end]) + return ret_val + + assert param_id in self.id_to_fake_params + fake_param = self.id_to_fake_params[param_id] + _, state_offset, param_size = self.get_offsets(param_id) + state_range = (state_offset, state_offset + param_size) + + # Copy states assigned to param (and cast tensors to appropriate types). + updated_states = dict() + for k, v in saved_states.items(): + updated_states[k] = cast(fake_param, state_range, v, k) + del v # clean loaded states + self.optim.state[fake_param].update(updated_states) + + def load_state_dict(self, state_dict: dict): + """Loads optimizer state from whole optimizer state_dict. + During loading, filter out the part of states not considered by current process. + + Args: + state_dict (dict): optimizer state. Should be an object returned + from a call to :meth:`state_dict`. + """ + assert 'param_groups' in state_dict + self.load_param_groups(state_dict['param_groups']) + + state = state_dict['state'] + + for param_id, param_states in state.items(): + if param_id in self.id_to_fake_params: + self.load_single_param_states(param_id, param_states) + + # Epilogue for pytorch optimizer. + self.optim._hook_for_profile() # To support multiprocessing pickle/unpickle. 
+ self.optim.defaults.setdefault('differentiable', False) class GeminiAdamOptimizer(ZeroOptimizer): diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py index 602cf468c..0235ff2e2 100644 --- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py @@ -8,15 +8,18 @@ from utils import shared_tempdir import colossalai from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin -from colossalai.booster.plugin.gemini_plugin import GeminiCheckpointIO from colossalai.nn.optimizer import HybridAdam -from colossalai.testing import check_state_dict_equal, parameterize, rerun_if_address_is_in_use, spawn -from colossalai.zero import ZeroDDP -from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration -from colossalai.zero.gemini.gemini_mgr import GeminiManager +from colossalai.testing import ( + check_state_dict_equal, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) from tests.kit.model_zoo import model_zoo +@clear_cache_before_run() @parameterize('placement_policy', ['cuda', 'cpu']) @parameterize('model_name', ['transformers_bert_for_sequence_classification']) @parameterize('use_safetensors', [False, True]) @@ -29,33 +32,33 @@ def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: b pretrained_path = os.path.join(tempdir, 'pretrained') bert_model.config.save_pretrained(save_directory=pretrained_path) - # TODO(ver217): use boost api - config_dict, *_ = search_chunk_configuration(bert_model, search_range_m=1, search_interval=100) - chunk_manager = ChunkManager(config_dict) - gemini_manager = GeminiManager(placement_policy, chunk_manager) - bert_model = ZeroDDP(bert_model, gemini_manager) - bert_model.train() - - ckpt_io = GeminiCheckpointIO() + plugin = GeminiPlugin(placement_policy=placement_policy) + booster = Booster(plugin=plugin) + bert_model, _, _, _, _ = booster.boost(bert_model) model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2 - ckpt_io.save_model(bert_model, (pretrained_path), + + booster.save_model(bert_model, + pretrained_path, True, True, '', (model_size / 3), use_safetensors=use_safetensors) dist.barrier() + new_bert_model = BertForSequenceClassification.from_pretrained(pretrained_path) - check_state_dict_equal(bert_model.state_dict(only_rank_0=False, dtype=torch.float32), + check_state_dict_equal(bert_model.unwrap().state_dict(only_rank_0=False, dtype=torch.float32), new_bert_model.state_dict(), False) +@clear_cache_before_run() @parameterize('placement_policy', ['cuda', 'cpu']) -@parameterize('shard', [True, False]) +@parameterize('shard', [False]) @parameterize('model_name', ['transformers_gpt']) -def exam_state_dict(placement_policy, shard: bool, model_name: str): +@parameterize('size_per_shard', [32]) +def exam_state_dict(placement_policy, shard: bool, model_name: str, size_per_shard: int): (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) criterion = lambda x: x.mean() - plugin = GeminiPlugin(placement_policy=placement_policy) + plugin = GeminiPlugin(placement_policy=placement_policy, precision="fp16", initial_scale=(2**14)) booster = Booster(plugin=plugin) model = model_fn() @@ -78,18 +81,32 @@ def exam_state_dict(placement_policy, shard: bool, model_name: str): with shared_tempdir() as tempdir: model_ckpt_path = f"{tempdir}/model" 
optimizer_ckpt_path = f"{tempdir}/optimizer" - booster.save_model(model, model_ckpt_path) - if not shard: - # TODO(ver217): optimizer checkpointing is not supported for sharded checkpoint - booster.save_optimizer(optimizer, optimizer_ckpt_path) + booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard) + + booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard) dist.barrier() booster.load_model(new_model, model_ckpt_path) check_state_dict_equal(model.unwrap().state_dict(only_rank_0=False), new_model.unwrap().state_dict(only_rank_0=False), False) - if not shard: - booster.load_optimizer(new_optimizer, optimizer_ckpt_path) - check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False) + + booster.load_optimizer(new_optimizer, optimizer_ckpt_path) + check_state_dict_equal(optimizer.unwrap().state_dict(only_rank_0=False), + new_optimizer.unwrap().state_dict(only_rank_0=False), False) + + # Check the new model/optimizer can successfully run. + data = data_gen_fn() + data = { + k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items() + } + output = new_model(**data) + output = output_transform_fn(output) + output_key = list(output.keys())[0] + loss = criterion(output[output_key]) + booster.backward(loss, new_optimizer) + new_optimizer.step() + booster.save_model(new_model, model_ckpt_path, shard=shard) + booster.save_optimizer(new_optimizer, optimizer_ckpt_path, shard=shard) def run_dist(rank, world_size, port): @@ -100,7 +117,7 @@ def run_dist(rank, world_size, port): @pytest.mark.dist -@pytest.mark.parametrize('world_size', [2]) +@pytest.mark.parametrize('world_size', [1, 2]) @rerun_if_address_is_in_use() def test_gemini_ckpIO(world_size): spawn(run_dist, world_size) diff --git a/tests/test_checkpoint_io/test_gemini_torch_compability.py b/tests/test_checkpoint_io/test_gemini_torch_compability.py new file mode 100644 index 000000000..b34e3e3a1 --- /dev/null +++ b/tests/test_checkpoint_io/test_gemini_torch_compability.py @@ -0,0 +1,171 @@ +import pytest +import torch +import torch.distributed as dist +from torch.optim import Adam +from utils import shared_tempdir + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, TorchDDPPlugin +from colossalai.nn.optimizer import HybridAdam +from colossalai.testing import ( + check_state_dict_equal, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) +from tests.kit.model_zoo import model_zoo + + +@clear_cache_before_run() +@parameterize('shard', [False]) +@parameterize('model_name', ['transformers_gpt']) +def exam_torch_load_from_gemini(shard: bool, model_name: str): + + (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) + criterion = lambda x: x.mean() + plugin = GeminiPlugin(precision="fp16", initial_scale=(2**14)) + booster = Booster(plugin=plugin) + + model = model_fn() + optimizer = HybridAdam(model.parameters(), lr=0.001) + model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) + + data = data_gen_fn() + data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()} + output = model(**data) + output = output_transform_fn(output) + output_key = list(output.keys())[0] + loss = criterion(output[output_key]) + + booster.backward(loss, optimizer) + optimizer.step() + + with shared_tempdir() 
as tempdir: + model_ckpt_path = f"{tempdir}/model" + optimizer_ckpt_path = f"{tempdir}/optimizer" + + booster.save_model(model, model_ckpt_path, shard=shard) + booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard) + dist.barrier() + + new_model = model_fn() + new_optimizer = Adam(new_model.parameters(), lr=0.001) + new_plugin = TorchDDPPlugin() + new_booster = Booster(plugin=new_plugin) + new_model, new_optimizer, criterion, _, _ = new_booster.boost(new_model, new_optimizer, criterion) + + # Loading HybridAdam states to torch.Adam + new_booster.load_model(new_model, model_ckpt_path, strict=True) + + # Add prefix to get aligned with pytorch parameter names. + check_state_dict_equal( + model.unwrap().state_dict(only_rank_0=False, prefix='module.module.', dtype=torch.float32), + new_model.state_dict(), False) + + new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path) + check_state_dict_equal(optimizer.unwrap().state_dict(only_rank_0=False), new_optimizer.state_dict(), False) + + # Check the new model/optimizer can successfully run. + data = data_gen_fn() + data = { + k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items() + } + output = new_model(**data) + output = output_transform_fn(output) + output_key = list(output.keys())[0] + loss = criterion(output[output_key]) + new_booster.backward(loss, new_optimizer) + new_optimizer.step() + new_booster.save_model(new_model, model_ckpt_path, shard=shard) + new_booster.save_optimizer(new_optimizer, optimizer_ckpt_path, shard=shard) + + +@clear_cache_before_run() +@parameterize('shard', [False]) +@parameterize('model_name', ['transformers_gpt']) +def exam_gemini_load_from_torch(shard: bool, model_name: str): + + (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) + criterion = lambda x: x.mean() + plugin = TorchDDPPlugin() + booster = Booster(plugin=plugin) + + model = model_fn() + optimizer = Adam(model.parameters(), lr=0.001) + model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) + + data = data_gen_fn() + data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()} + output = model(**data) + output = output_transform_fn(output) + output_key = list(output.keys())[0] + loss = criterion(output[output_key]) + + booster.backward(loss, optimizer) + optimizer.step() + + with shared_tempdir() as tempdir: + model_ckpt_path = f"{tempdir}/model" + optimizer_ckpt_path = f"{tempdir}/optimizer" + + booster.save_model(model, model_ckpt_path, shard=shard) + booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard) + dist.barrier() + + new_model = model_fn() + new_optimizer = HybridAdam(new_model.parameters(), lr=0.001) + new_plugin = GeminiPlugin() + new_booster = Booster(plugin=new_plugin) + new_model, new_optimizer, criterion, _, _ = new_booster.boost(new_model, new_optimizer, criterion) + + # Loading torch.Adam states to HybridAdam + new_booster.load_model(new_model, model_ckpt_path, strict=True) + + # Add prefix to get aligned with pytorch parameter names. 
+ check_state_dict_equal( + new_model.unwrap().state_dict(only_rank_0=False, prefix='module.module.', dtype=torch.float32), + model.state_dict(), False) + + new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path) + old_state_dict = optimizer.state_dict() + new_state_dict = new_optimizer.unwrap().state_dict(only_rank_0=False) + + # Comparison of param_groups needs special care here, + # since not all hyperparameters in Adam are used by HybridAdam + hyperparameters_to_examine = ['params', 'lr', 'betas', 'eps', 'weight_decay'] + for old_group, new_group in zip(old_state_dict['param_groups'], new_state_dict['param_groups']): + for k in hyperparameters_to_examine: + assert k in old_group and k in new_group, \ + f"Old group's keys: {list(old_group.keys())}, New group's keys: {list(new_group.keys())}" + assert old_group[k] == new_group[k] + check_state_dict_equal(old_state_dict['state'], new_state_dict['state'], False) + + # Check the new model/optimizer can successfully run. + data = data_gen_fn() + data = { + k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items() + } + output = new_model(**data) + output = output_transform_fn(output) + output_key = list(output.keys())[0] + loss = criterion(output[output_key]) + new_booster.backward(loss, new_optimizer) + new_optimizer.step() + new_booster.save_model(new_model, model_ckpt_path, shard=shard) + new_booster.save_optimizer(new_optimizer, optimizer_ckpt_path, shard=shard) + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + exam_torch_load_from_gemini() + exam_gemini_load_from_torch() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 2]) +@rerun_if_address_is_in_use() +def test_gemini_ckpIO(world_size): + spawn(run_dist, world_size)