From f313babd11f8137c2496e7dc54c6b61604cd3672 Mon Sep 17 00:00:00 2001
From: Hongxin Liu <lhx0217@gmail.com>
Date: Mon, 17 Apr 2023 17:11:09 +0800
Subject: [PATCH] [gemini] support save state dict in shards (#3581)

* [gemini] support state dict shard

* [gemini] add test state dict shard

* [gemini] polish docstr

* [gemini] fix merge

* [gemini] polish code
---
 colossalai/zero/gemini/gemini_ddp.py          | 129 ++++++++++++++++--
 .../test_zeroddp_state_dict_shard.py          |  56 ++++++++
 2 files changed, 172 insertions(+), 13 deletions(-)
 create mode 100644 tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py

diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py
index 2e35be066..9a193310b 100644
--- a/colossalai/zero/gemini/gemini_ddp.py
+++ b/colossalai/zero/gemini/gemini_ddp.py
@@ -1,12 +1,13 @@
 import itertools
 from collections import OrderedDict
 from functools import partial
-from typing import Dict, List, Optional, Union
+from typing import Dict, Iterator, List, Optional, Union
 
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 
+from colossalai.checkpoint_io.utils import calculate_tensor_size
 from colossalai.logging import get_dist_logger
 from colossalai.nn.parallel.data_parallel import ColoDDP, _cast_float, free_storage
 from colossalai.tensor import ProcessGroup as ColoProcessGroup
@@ -228,6 +229,32 @@ class ZeroDDP(ColoDDP):
                 destination = hook_result
         return destination
 
+    def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool) -> Dict:
+        """
+        get gathered chunk content.
+
+        Args:
+            chunk (Chunk): a chunk
+            only_rank_0 (bool): whether to only save data on rank 0
+
+        Returns:
+            Dict: a dict whose key is param name and value is param with correct payload
+        """
+        # save parameters
+        chunk_to_save_data = dict()
+        temp_chunk = get_temp_total_chunk_on_cuda(chunk)
+        for tensor, tensor_info in chunk.tensors_info.items():
+            record_tensor = torch.empty([0])
+            record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0)
+            if record_flag:
+                record_tensor = temp_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape).cpu()
+
+            assert tensor not in chunk_to_save_data
+            chunk_to_save_data[tensor] = record_tensor
+
+        del temp_chunk
+        return chunk_to_save_data
+
     def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool) -> Dict:
         """
         get param content from chunks.
@@ -243,18 +270,7 @@ class ZeroDDP(ColoDDP):
         param_to_save_data = dict()
         chunk_list = self.chunk_manager.get_chunks(param_list)
         for chunk in chunk_list:
-            temp_chunk = get_temp_total_chunk_on_cuda(chunk)
-
-            for tensor, tensor_info in chunk.tensors_info.items():
-                record_tensor = torch.empty([0])
-                record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0)
-                if record_flag:
-                    record_tensor = temp_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape).cpu()
-
-                assert tensor not in param_to_save_data
-                param_to_save_data[tensor] = record_tensor
-
-            del temp_chunk
+            param_to_save_data.update(self._get_chunk_to_save_data(chunk, only_rank_0))
         return param_to_save_data
 
     def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
@@ -554,6 +570,93 @@ class ZeroDDP(ColoDDP):
         p.__class__ = ColoParameter
         p.__init__(p, requires_grad=requires_grad)
 
+    def state_dict_shard(self,
+                         prefix: str = '',
+                         keep_vars: bool = False,
+                         max_shard_size: int = 1024,
+                         only_rank_0: bool = True) -> Iterator[OrderedDict]:
+        """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.
+
+        Both parameters and persistent buffers (e.g. running averages) are included.
+        Keys are corresponding parameter and buffer names.
+        Parameters and buffers set to ``None`` are not included.
+
+        Args:
+            prefix (str, optional): the prefix for parameters and buffers used in this
+                module. Defaults to ''.
+            keep_vars (bool, optional): whether to keep variables. Defaults to False.
+            max_shard_size (int, optional): max size of state dict shard (in MB). Defaults to 1024.
+            only_rank_0 (bool, optional): only get data on rank0. Defaults to True.
+
+
+        Yields:
+            Iterator[OrderedDict]: A generator of state dict shard
+        """
+        sharder = _StateDictSharder(max_shard_size)
+
+        # get the mapping between copies and fp16 parameters
+        fp16_to_fp32 = dict()
+        for p, fp32_p in zip(self.fp16_params, self.fp32_params):
+            fp16_to_fp32[p] = fp32_p
+
+        # key is fp32 param, and value is gathered param on CPU
+        gathered_param_buffer = dict()
+        for name, param in self.name2param.items():
+            if param is not None:
+                if is_ddp_ignored(param):
+                    # deal with ddp ignored parameters
+                    gathered_param = param if keep_vars else param.detach()
+                else:
+                    fp32_param = fp16_to_fp32[param]
+                    if fp32_param not in gathered_param_buffer:
+                        chunk = self.chunk_manager.get_chunk(fp32_param)
+                        gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0))
+                    gathered_param = gathered_param_buffer.pop(fp32_param)
+
+                block = sharder.append(prefix + name, gathered_param)
+                if block is not None:
+                    yield block
+
+        del fp16_to_fp32
+        del gathered_param_buffer
+
+        # save all buffers
+        for name, buf in self.named_buffers():
+            if buf is not None and name not in self._non_persistent_buffers_set:
+                buffer = buf if keep_vars else buf.detach()
+                block = sharder.append(prefix + name, buffer)
+                if block is not None:
+                    yield block
+        # save extra states
+        extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
+        if getattr(self.__class__, "get_extra_state",
+                   torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
+            extra_state = self.get_extra_state()
+            block = sharder.append(extra_state_key, extra_state)
+            if block is not None:
+                yield block
+
+        yield sharder.current_block
+
+
+class _StateDictSharder:
+
+    def __init__(self, max_shard_size: int) -> None:
+        self.max_shard_size = max_shard_size
+        self.current_block = OrderedDict()
+        self.current_block_size = 0
+
+    def append(self, name: str, tensor: torch.Tensor) -> Optional[OrderedDict]:
+        tensor_size = calculate_tensor_size(tensor)
+        ret_block = None
+        if self.current_block_size + tensor_size > self.max_shard_size:
+            ret_block = self.current_block
+            self.current_block = OrderedDict()
+            self.current_block_size = 0
+        self.current_block[name] = tensor
+        self.current_block_size += tensor_size
+        return ret_block
+
 
 class GeminiDDP(ZeroDDP):
 
diff --git a/tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py b/tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py
new file mode 100644
index 000000000..96c26a1de
--- /dev/null
+++ b/tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py
@@ -0,0 +1,56 @@
+import pytest
+import torch
+from torch.testing import assert_close
+
+import colossalai
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.utils.cuda import get_current_device
+from colossalai.zero import ColoInitContext, ZeroDDP
+from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
+from colossalai.zero.gemini.gemini_mgr import GeminiManager
+from tests.components_to_test.registry import non_distributed_component_funcs
+
+
+@parameterize('placement_policy', ['cuda', 'cpu'])
+@parameterize('model_name', ['gpt2', 'bert'])
+def exam_state_dict(placement_policy, model_name: str):
+    get_components_func = non_distributed_component_funcs.get_callable(model_name)
+    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
+
+    with ColoInitContext(device=get_current_device()):
+        model = model_builder()
+
+    model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2
+
+    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    chunk_manager = ChunkManager(config_dict)
+    gemini_manager = GeminiManager(placement_policy, chunk_manager)
+    model = ZeroDDP(model, gemini_manager)
+    model.train()
+
+    zero_dict = model.state_dict(only_rank_0=False)
+    accumulated_keys = set()
+    # ensure number of shards > 1
+    for shard in model.state_dict_shard(max_shard_size=(model_size / 3), only_rank_0=False):
+        for key, value in shard.items():
+            assert key not in accumulated_keys, f"key `{key}` is duplicated."
+            accumulated_keys.add(key)
+            assert key in zero_dict, f"{key} not in ZeRO dictionary."
+            assert torch.equal(value, zero_dict[key]), f"{key} not equal."
+
+
+def run_dist(rank, world_size, port):
+    config = {}
+    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    exam_state_dict()
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize('world_size', [1, 4])
+@rerun_if_address_is_in_use()
+def test_zero_ddp_state_dict_shard(world_size):
+    spawn(run_dist, world_size)
+
+
+if __name__ == '__main__':
+    test_zero_ddp_state_dict_shard(1)