From f313babd11f8137c2496e7dc54c6b61604cd3672 Mon Sep 17 00:00:00 2001
From: Hongxin Liu <lhx0217@gmail.com>
Date: Mon, 17 Apr 2023 17:11:09 +0800
Subject: [PATCH] [gemini] support save state dict in shards (#3581)

* [gemini] support state dict shard
* [gemini] add test state dict shard
* [gemini] polish docstr
* [gemini] fix merge
* [gemini] polish code
---
 colossalai/zero/gemini/gemini_ddp.py          | 129 ++++++++++++++++--
 .../test_zeroddp_state_dict_shard.py          |  56 ++++++++
 2 files changed, 172 insertions(+), 13 deletions(-)
 create mode 100644 tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py

diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py
index 2e35be066..9a193310b 100644
--- a/colossalai/zero/gemini/gemini_ddp.py
+++ b/colossalai/zero/gemini/gemini_ddp.py
@@ -1,12 +1,13 @@
 import itertools
 from collections import OrderedDict
 from functools import partial
-from typing import Dict, List, Optional, Union
+from typing import Dict, Iterator, List, Optional, Union
 
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 
+from colossalai.checkpoint_io.utils import calculate_tensor_size
 from colossalai.logging import get_dist_logger
 from colossalai.nn.parallel.data_parallel import ColoDDP, _cast_float, free_storage
 from colossalai.tensor import ProcessGroup as ColoProcessGroup
@@ -228,6 +229,32 @@ class ZeroDDP(ColoDDP):
                 destination = hook_result
         return destination
 
+    def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool) -> Dict:
+        """
+        get gathered chunk content.
+
+        Args:
+            chunk (Chunk): a chunk
+            only_rank_0 (bool): whether to only save data on rank 0
+
+        Returns:
+            Dict: a dict whose key is param name and value is param with correct payload
+        """
+        # save parameters
+        chunk_to_save_data = dict()
+        temp_chunk = get_temp_total_chunk_on_cuda(chunk)
+        for tensor, tensor_info in chunk.tensors_info.items():
+            record_tensor = torch.empty([0])
+            record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0)
+            if record_flag:
+                record_tensor = temp_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape).cpu()
+
+            assert tensor not in chunk_to_save_data
+            chunk_to_save_data[tensor] = record_tensor
+
+        del temp_chunk
+        return chunk_to_save_data
+
     def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool) -> Dict:
         """
         get param content from chunks.
@@ -243,18 +270,7 @@ class ZeroDDP(ColoDDP):
         param_to_save_data = dict()
         chunk_list = self.chunk_manager.get_chunks(param_list)
         for chunk in chunk_list:
-            temp_chunk = get_temp_total_chunk_on_cuda(chunk)
-
-            for tensor, tensor_info in chunk.tensors_info.items():
-                record_tensor = torch.empty([0])
-                record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0)
-                if record_flag:
-                    record_tensor = temp_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape).cpu()
-
-                assert tensor not in param_to_save_data
-                param_to_save_data[tensor] = record_tensor
-
-            del temp_chunk
+            param_to_save_data.update(self._get_chunk_to_save_data(chunk, only_rank_0))
         return param_to_save_data
 
     def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
@@ -554,6 +570,93 @@ class ZeroDDP(ColoDDP):
             p.__class__ = ColoParameter
             p.__init__(p, requires_grad=requires_grad)
 
+    def state_dict_shard(self,
+                         prefix: str = '',
+                         keep_vars: bool = False,
+                         max_shard_size: int = 1024,
+                         only_rank_0: bool = True) -> Iterator[OrderedDict]:
+        """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.
+
+        Both parameters and persistent buffers (e.g. running averages) are included.
+        Keys are corresponding parameter and buffer names.
+        Parameters and buffers set to ``None`` are not included.
+
+        Args:
+            prefix (str, optional): the prefix for parameters and buffers used in this
+                module. Defaults to ''.
+            keep_vars (bool, optional): whether to keep variables. Defaults to False.
+            max_shard_size (int, optional): max size of state dict shard (in MB). Defaults to 1024.
+            only_rank_0 (bool, optional): only get data on rank0. Defaults to True.
+
+
+        Yields:
+            Iterator[OrderedDict]: A generator of state dict shard
+        """
+        sharder = _StateDictSharder(max_shard_size)
+
+        # get the mapping between copies and fp16 parameters
+        fp16_to_fp32 = dict()
+        for p, fp32_p in zip(self.fp16_params, self.fp32_params):
+            fp16_to_fp32[p] = fp32_p
+
+        # key is fp32 param, and value is gathered param on CPU
+        gathered_param_buffer = dict()
+        for name, param in self.name2param.items():
+            if param is not None:
+                if is_ddp_ignored(param):
+                    # deal with ddp ignored parameters
+                    gathered_param = param if keep_vars else param.detach()
+                else:
+                    fp32_param = fp16_to_fp32[param]
+                    if fp32_param not in gathered_param_buffer:
+                        chunk = self.chunk_manager.get_chunk(fp32_param)
+                        gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0))
+                    gathered_param = gathered_param_buffer.pop(fp32_param)
+
+                block = sharder.append(prefix + name, gathered_param)
+                if block is not None:
+                    yield block
+
+        del fp16_to_fp32
+        del gathered_param_buffer
+
+        # save all buffers
+        for name, buf in self.named_buffers():
+            if buf is not None and name not in self._non_persistent_buffers_set:
+                buffer = buf if keep_vars else buf.detach()
+                block = sharder.append(prefix + name, buffer)
+                if block is not None:
+                    yield block
+        # save extra states
+        extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
+        if getattr(self.__class__, "get_extra_state",
+                   torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
+            extra_state = self.get_extra_state()
+            block = sharder.append(extra_state_key, extra_state)
+            if block is not None:
+                yield block
+
+        yield sharder.current_block
+
+
+class _StateDictSharder:
+
+    def __init__(self, max_shard_size: int) -> None:
+        self.max_shard_size = max_shard_size
+        self.current_block = OrderedDict()
+        self.current_block_size = 0
+
+    def append(self, name: str, tensor: torch.Tensor) -> Optional[OrderedDict]:
+        tensor_size = calculate_tensor_size(tensor)
+        ret_block = None
+        if self.current_block_size + tensor_size > self.max_shard_size:
+            ret_block = self.current_block
+            self.current_block = OrderedDict()
+            self.current_block_size = 0
+        self.current_block[name] = tensor
+        self.current_block_size += tensor_size
+        return ret_block
+
 
 class GeminiDDP(ZeroDDP):
 
diff --git a/tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py b/tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py
new file mode 100644
index 000000000..96c26a1de
--- /dev/null
+++ b/tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py
@@ -0,0 +1,56 @@
+import pytest
+import torch
+from torch.testing import assert_close
+
+import colossalai
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.utils.cuda import get_current_device
+from colossalai.zero import ColoInitContext, ZeroDDP
+from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
+from colossalai.zero.gemini.gemini_mgr import GeminiManager
+from tests.components_to_test.registry import non_distributed_component_funcs
+
+
+@parameterize('placement_policy', ['cuda', 'cpu'])
+@parameterize('model_name', ['gpt2', 'bert'])
+def exam_state_dict(placement_policy, model_name: str):
+    get_components_func = non_distributed_component_funcs.get_callable(model_name)
+    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
+
+    with ColoInitContext(device=get_current_device()):
+        model = model_builder()
+
+    model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2
+
+    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    chunk_manager = ChunkManager(config_dict)
+    gemini_manager = GeminiManager(placement_policy, chunk_manager)
+    model = ZeroDDP(model, gemini_manager)
+    model.train()
+
+    zero_dict = model.state_dict(only_rank_0=False)
+    accumulated_keys = set()
+    # ensure number of shards > 1
+    for shard in model.state_dict_shard(max_shard_size=(model_size / 3), only_rank_0=False):
+        for key, value in shard.items():
+            assert key not in accumulated_keys, f"key `{key}` is duplicated."
+            accumulated_keys.add(key)
+            assert key in zero_dict, f"{key} not in ZeRO dictionary."
+            assert torch.equal(value, zero_dict[key]), f"{key} not equal."
+
+
+def run_dist(rank, world_size, port):
+    config = {}
+    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    exam_state_dict()
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize('world_size', [1, 4])
+@rerun_if_address_is_in_use()
+def test_zero_ddp_state_dict_shard(world_size):
+    spawn(run_dist, world_size)
+
+
+if __name__ == '__main__':
+    test_zero_ddp_state_dict_shard(1)
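
Usage sketch: the snippet below shows one way the state_dict_shard() generator introduced by this patch might be consumed to write a sharded checkpoint. It assumes a module already wrapped in ZeroDDP/GeminiDDP (set up as in the test above); the helper name save_sharded_state_dict, the shard file names, and the index file are illustrative assumptions, not part of this patch or of the ColossalAI API.

import os
from typing import Dict

import torch
import torch.distributed as dist


def save_sharded_state_dict(model, save_dir: str, max_shard_size_mb: int = 1024) -> None:
    """Write each shard yielded by ``model.state_dict_shard`` to its own file (illustrative helper)."""
    # With the default only_rank_0=True, every rank must still iterate the generator
    # (chunk gathering is collective), but only rank 0 receives the real tensor payloads.
    is_rank_0 = (not dist.is_initialized()) or dist.get_rank() == 0
    if is_rank_0:
        os.makedirs(save_dir, exist_ok=True)

    key_to_file: Dict[str, str] = {}
    for i, shard in enumerate(model.state_dict_shard(max_shard_size=max_shard_size_mb)):
        shard_file = f"pytorch_model-{i:05d}.bin"    # hypothetical naming scheme
        key_to_file.update({key: shard_file for key in shard})
        if is_rank_0:
            torch.save(shard, os.path.join(save_dir, shard_file))

    if is_rank_0:
        # a simple key -> file index so individual tensors can be located later
        torch.save(key_to_file, os.path.join(save_dir, "shard_index.bin"))

Because the generator always yields its final block (even when it is smaller than max_shard_size), the loop above sees every parameter and buffer exactly once.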