From dbb32692d293d865640d857e8ee16d653147f7bc Mon Sep 17 00:00:00 2001
From: Hongxin Liu <lhx0217@gmail.com>
Date: Mon, 5 Jun 2023 14:20:47 +0800
Subject: [PATCH] [lazy] refactor lazy init (#3891)

* [lazy] remove old lazy init

* [lazy] refactor lazy init folder structure

* [lazy] fix lazy tensor deepcopy

* [test] update lazy init test
---
 colossalai/lazy/__init__.py                   |   6 +
 .../experimental.py => lazy/lazy_init.py}     |   9 +-
 colossalai/utils/model/lazy_init_context.py   | 242 ------------------
 colossalai/zero/gemini/gemini_ddp.py          |  58 +++--
 .../test_plugin/test_gemini_plugin.py         |   2 +-
 .../lazy_init_utils.py                        |  10 +-
 .../test_distribute.py                        |   2 +-
 .../test_models.py                            |   0
 tests/test_utils/test_lazy_init_ctx.py        |  51 ----
 9 files changed, 56 insertions(+), 324 deletions(-)
 create mode 100644 colossalai/lazy/__init__.py
 rename colossalai/{utils/model/experimental.py => lazy/lazy_init.py} (98%)
 delete mode 100644 colossalai/utils/model/lazy_init_context.py
 rename tests/{test_utils/test_lazy_init => test_lazy}/lazy_init_utils.py (85%)
 rename tests/{test_utils/test_lazy_init => test_lazy}/test_distribute.py (97%)
 rename tests/{test_utils/test_lazy_init => test_lazy}/test_models.py (100%)
 delete mode 100644 tests/test_utils/test_lazy_init_ctx.py

diff --git a/colossalai/lazy/__init__.py b/colossalai/lazy/__init__.py
new file mode 100644
index 000000000..4387107bf
--- /dev/null
+++ b/colossalai/lazy/__init__.py
@@ -0,0 +1,6 @@
+from .lazy_init import LazyInitContext, LazyTensor
+
+__all__ = [
+    'LazyInitContext',
+    'LazyTensor',
+]
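
After this refactor, the public entry point is the new colossalai.lazy package. A minimal usage sketch (hedged: it assumes materialize keeps the instance-method form used in the tests further down in this patch, and that a plain nn.Linear is supported):

    import torch.nn as nn

    from colossalai.lazy import LazyInitContext

    ctx = LazyInitContext()
    with ctx:
        # parameters are created as LazyTensors; no real storage is allocated yet
        model = nn.Linear(10, 10)
    # materialize later, e.g. right before handing the model to a plugin
    model = ctx.materialize(model)
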
diff --git a/colossalai/utils/model/experimental.py b/colossalai/lazy/lazy_init.py
similarity index 98%
rename from colossalai/utils/model/experimental.py
rename to colossalai/lazy/lazy_init.py
index bf3e3d05b..c1fda3c53 100644
--- a/colossalai/utils/model/experimental.py
+++ b/colossalai/lazy/lazy_init.py
@@ -350,7 +350,14 @@ class LazyTensor(torch.Tensor):
                 copied.requires_grad_()
             return copied
 
-        target = LazyTensor(factory_fn, meta_data=self._meta_data)
+        if self._materialized_data is not None:
+            # self is early materialized
+            copied = self._materialized_data.detach().clone()
+            if self.requires_grad:
+                copied.requires_grad_()
+            target = LazyTensor(lambda: None, concrete_data=copied)
+        else:
+            target = LazyTensor(factory_fn, meta_data=self._meta_data)
 
         memo[id(self)] = target
         return target
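
The hunk above only touches the already-materialized branch of LazyTensor.__deepcopy__. A self-contained stand-in (not the ColossalAI class) showing the intended semantics: copy the concrete values when they exist, otherwise copy only the factory.

    from copy import deepcopy

    import torch

    class _FakeLazy:
        # hedged stand-in for LazyTensor: holds either a factory or concrete data
        def __init__(self, factory=None, data=None):
            self.factory = factory
            self.data = data

        def __deepcopy__(self, memo):
            if self.data is not None:
                # early materialized: clone the values, mirroring the fix above
                target = _FakeLazy(data=self.data.detach().clone())
            else:
                # still deferred: copy the recipe, not the (nonexistent) values
                target = _FakeLazy(factory=self.factory)
            memo[id(self)] = target
            return target

    early = _FakeLazy(data=torch.randn(2, 2))
    assert torch.equal(deepcopy(early).data, early.data)
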
diff --git a/colossalai/utils/model/lazy_init_context.py b/colossalai/utils/model/lazy_init_context.py
deleted file mode 100644
index cf05f9660..000000000
--- a/colossalai/utils/model/lazy_init_context.py
+++ /dev/null
@@ -1,242 +0,0 @@
-#!/usr/bin/env python
-# coding: utf-8
-
-import inspect
-import types
-from typing import Callable, List
-
-import torch
-import torch.nn as nn
-
-from colossalai.tensor import ColoParameter, ColoTensor
-from colossalai.utils.model.utils import substitute_init_recursively
-
-
-class LazyInitContext():
-    """
-    A context to allow for lazy weight initialization of PyTorch modules. It intercepts the tensor
-    initialization functions for lazy initialization
-
-    Note:
-        This API is only experimental and subject to future changes.
-
-    Usage:
-        with LazyInitContext() as ctx:
-            model = nn.Linear(10, 10)
-            model.weight.zero_()
-
-        # make sure the weight is a meta tensor
-        assert model.weight.is_meta
-
-        # initialize weights
-        ctx.lazy_init_parameters(model)
-
-        # make sure the weight is not a meta tensor
-        # and initialized correctly
-        assert not model.weight.is_meta and torch.all(model.weight == 0)
-
-    Args:
-        to_meta (bool): optional, whether to initialize the model with meta tensors, default is True. This
-            argument exists for now because some corner cases such as self.weight = torch.zeros(...) cannot be captured yet.
-        extra_torch_tensor_func (List[str]): extra torch tensor functions related
-            to value setting, such as `zero_` and `triu_`. `zero_` is pre-added by default.
-    """
-
-    tensor_set_value_func = ['zero_', 'fill_']
-
-    def __init__(self, to_meta: bool = True, extra_torch_tensor_func: List[str] = None):
-        # TODO: hijack the torch constructor functions as well
-        self._to_meta = to_meta
-        self._intercepted_nn_init_func_cache = {}
-        self._nn_init_methods = self._get_nn_init_methods()
-        self._torch_mod_cls = torch.nn.modules.module.Module
-
-        if extra_torch_tensor_func:
-            # use tuple to remove duplicates
-            self._torch_tensor_funcs = tuple(self.tensor_set_value_func + extra_torch_tensor_func)
-        else:
-            self._torch_tensor_funcs = self.tensor_set_value_func
-
-    @property
-    def to_meta(self):
-        return self._to_meta
-
-    def _cache_init_func(self, func):
-        """
-        This method wraps the ``torch.nn.init`` method and torch tensor value-setting functions
-        so that the function call is cached instead of being executed.
-        """
-
-        def wrapped_init_func(tensor, *args, **kwargs):
-            if tensor not in self._intercepted_nn_init_func_cache:
-                self._intercepted_nn_init_func_cache[tensor] = []
-            self._intercepted_nn_init_func_cache[tensor].append((func, args, kwargs))
-
-        return wrapped_init_func
-
-    def _get_nn_init_methods(self):
-        """
-        This method looks for all available functions in the ``torch.nn.init``
-        module.
-        """
-        nn_init_method_names = dir(torch.nn.init)
-        nn_init_methods = []
-
-        # look for all methods in ``torch.nn.init`` module
-        for name in nn_init_method_names:
-            nn_init_methods.append((name, getattr(torch.nn.init, name)))
-
-        def _is_init_method(item):
-            name, func = item
-
-            if (not isinstance(func, types.FunctionType) or name.startswith('_') or not name.endswith('_')):
-                return False
-            else:
-                return True
-
-        # remove methods which are not init functions
-        nn_init_methods = list(filter(_is_init_method, nn_init_methods))
-        return nn_init_methods
-
-    def _wrap_module_init(self, func):
-        """
-        This method wraps the calls to the `__init__` of ``torch.nn.Module`` and replaces
-        the argument device with value 'meta' so that all modules are created as meta tensors.
-        """
-        has_device = 'device' in inspect.signature(func).parameters
-
-        def layer_lazy_init(module, *args, **kwargs):
-            # if this module contains device argument
-            # we set it to meta to initialize as meta backend
-            if has_device:
-                kwargs['device'] = 'meta'
-            func(module, *args, **kwargs)
-
-            # if device is not found, we initialize it and convert to meta
-            if not has_device:
-                module.to('meta')
-
-        return layer_lazy_init
-
-    def _get_tmp_origin_func_ref(self, name):
-        """
-        Generate a function name for consistency during caching and retrieving.
-        """
-        return f'_orig_{name}'
-
-    def _patch_nn_init_funcs(self):
-        # patch nn.init functions
-        for name, func in self._nn_init_methods:
-            setattr(torch.nn.init, name, self._cache_init_func(func))
-
-    def _unpatch_nn_init_funcs(self):
-        # unpatch nn.init functions
-        for name, func in self._nn_init_methods:
-            setattr(torch.nn.init, name, func)
-
-    def _patch_submodule_init(self):
-        # patch classes __init__ methods
-        def _activate_wrap_init(cls):
-            cls.__orig_init__ = cls.__init__
-            cls.__init__ = self._wrap_module_init(cls.__init__)
-
-        substitute_init_recursively(self._torch_mod_cls, _activate_wrap_init, set())
-
-    def _unpatch_submodule_init(self):
-
-        def _recover_orig_init(cls):
-            cls.__init__ = cls.__orig_init__
-
-        substitute_init_recursively(self._torch_mod_cls, _recover_orig_init, set())
-
-    def _patch_torch_tensor_funcs(self):
-        # patch tensor value-setting functions
-        for func_name in self._torch_tensor_funcs:
-            origin_func_name = self._get_tmp_origin_func_ref(func_name)
-            origin_func = getattr(torch.Tensor, func_name)
-            setattr(torch.Tensor, origin_func_name, origin_func)
-            setattr(torch.Tensor, func_name, self._cache_init_func(origin_func))
-
-    def _unpatch_torch_tensor_funcs(self):
-        for func_name in self._torch_tensor_funcs:
-            origin_func_name = self._get_tmp_origin_func_ref(func_name)
-            origin_func = getattr(torch.Tensor, origin_func_name)
-            setattr(torch.Tensor, func_name, origin_func)
-
-    def __enter__(self):
-        self._patch_torch_tensor_funcs()
-        self._patch_nn_init_funcs()
-
-        if self._to_meta:
-            self._patch_submodule_init()
-        return self
-
-    def __exit__(self, *args, **kwargs):
-        if self._to_meta:
-            self._unpatch_submodule_init()
-        self._unpatch_nn_init_funcs()
-        self._unpatch_torch_tensor_funcs()
-
-    def lazy_init_parameters(self, model: torch.nn.Module, device='cpu'):
-        """
-        Initialize the weights of the meta-tensor model.
-
-        Args:
-            model (`torch.nn.Module`): the model instantiated under the context.
-            device (str): the device on which weights are initialized
-
-        """
-
-        def _init_recursively(module: nn.Module):
-            # recursively initialize the module
-            for mod in module.children():
-                _init_recursively(mod)
-
-            # initialize and shard tensors directly attached to the current module
-            for name, param in module.named_parameters(recurse=False):
-                _init_and_shard(module, name, param)
-
-            for name, buf in module.named_buffers(recurse=False):
-                _init_and_shard(module, name, buf)
-
-        @torch.no_grad()
-        def _init_and_shard(module, name, tensor):
-            # check whether the tensor is a buffer or parameter
-            is_param = isinstance(tensor, nn.parameter.Parameter)
-
-            # get sharding spec
-            dist_spec = getattr(tensor, 'dist_spec', None)
-            pg = getattr(tensor, 'pg', None)
-            comp_spec = getattr(tensor, 'comp_spec', None)
-
-            # convert the tensor from meta to materialized one
-            if tensor.is_meta:
-                materialized_tensor = torch.empty_like(tensor, device=device)
-                # if this tensor is a meta tensor, it must have an init function
-                assert tensor in self._intercepted_nn_init_func_cache
-            else:
-                materialized_tensor = tensor
-
-            # apply init function
-            if tensor in self._intercepted_nn_init_func_cache:
-                init_func, args, kwargs = self._intercepted_nn_init_func_cache[tensor][-1]
-                init_func(materialized_tensor, *args, **kwargs)
-
-            # convert it to ColoTensor or ColoParameter
-            if is_param:
-                tensor = ColoParameter.from_torch_tensor(materialized_tensor, requires_grad=tensor.requires_grad)
-            else:
-                tensor = ColoTensor.from_torch_tensor(materialized_tensor)
-
-            # override the original tensor
-            with torch.no_grad():
-                setattr(module, name, tensor)
-
-            # apply sharding
-            if dist_spec:
-                tensor.process_group = pg
-                tensor.set_tensor_spec(dist_spec, comp_spec)
-
-        _init_recursively(model)
-
-        return model
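
For context, the deleted LazyInitContext worked by patching module constructors onto the meta device and replaying cached nn.init calls later. The same pattern can be sketched with plain PyTorch (a hedged, standalone illustration, not the removed API; it assumes a PyTorch version whose modules accept a device kwarg, which holds for the >= 1.12 requirement used elsewhere in this patch):

    import torch
    import torch.nn as nn

    # build on the meta device: shapes and dtypes only, no real storage
    layer = nn.Linear(10, 10, device='meta')
    assert layer.weight.is_meta

    # materialize: allocate uninitialized storage, then apply the init function
    layer = layer.to_empty(device='cpu')
    nn.init.zeros_(layer.weight)
    assert not layer.weight.is_meta and torch.all(layer.weight == 0)
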
diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py
index 878c25be7..fd49362d6 100644
--- a/colossalai/zero/gemini/gemini_ddp.py
+++ b/colossalai/zero/gemini/gemini_ddp.py
@@ -2,13 +2,14 @@ import itertools
 from collections import OrderedDict
 from contextlib import nullcontext
 from functools import partial
-from typing import Dict, Iterator, List, Optional, Union, Tuple, Set
+from typing import Dict, Iterator, List, Optional, Set, Tuple, Union
 
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 
 from colossalai.checkpoint_io.utils import calculate_tensor_size
+from colossalai.lazy import LazyTensor
 from colossalai.logging import get_dist_logger
 from colossalai.nn.parallel.data_parallel import ColoDDP, _cast_float, free_storage
 from colossalai.tensor import ProcessGroup as ColoProcessGroup
@@ -16,7 +17,6 @@ from colossalai.tensor import ReplicaSpec
 from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.utils import get_current_device, is_ddp_ignored
-from colossalai.utils.model.experimental import LazyTensor
 
 from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager
 from .gemini_hook import GeminiZeROHook
@@ -96,34 +96,38 @@ class ZeroDDP(ColoDDP):
                 param_name = m_name + '.' + p_name if m_name else p_name
                 self.name2param[param_name] = p_var
         super().__init__(module, process_group=ColoProcessGroup())
-        self._non_persistent_buffers_set=self._get_non_persistent_buffers_set(module)
+        self._non_persistent_buffers_set = self._get_non_persistent_buffers_set(module)
         self._cast_buffers()
 
-    def _get_non_persistent_buffers_set(self, module, memo: Optional[Set[nn.Module]] = None, prefix: str = '', remove_duplicate: bool = True):
+    def _get_non_persistent_buffers_set(self,
+                                        module,
+                                        memo: Optional[Set[nn.Module]] = None,
+                                        prefix: str = '',
+                                        remove_duplicate: bool = True):
+        r"""
+        Args:
+            memo: a memo to store the set of modules already added to the result
+            prefix: a prefix that will be added to the name of the module
+            remove_duplicate: whether to remove the duplicated module instances in the result
+                or not
+        """
 
-            r"""
-            Args:
-                memo: a memo to store the set of modules already added to the result
-                prefix: a prefix that will be added to the name of the module
-                remove_duplicate: whether to remove the duplicated module instances in the result
-                    or not
-            """
-
-            if memo is None:
-                memo = set()
-            self_non_persistent_set = set()
-            if module not in memo:
-                if remove_duplicate:
-                    memo.add(module)
-                self_non_persistent_set = set(map(lambda key: prefix + ('.' if prefix else '') + key, module._non_persistent_buffers_set))
-                for name, sub_module in module._modules.items():
-                    if sub_module is None:
-                        continue
-                    submodule_prefix = prefix + ('.' if prefix else '') + name
-                    child_non_persistent_set = self._get_non_persistent_buffers_set(sub_module, memo, submodule_prefix, remove_duplicate)
-                    self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set)
-            return self_non_persistent_set
-    
+        if memo is None:
+            memo = set()
+        self_non_persistent_set = set()
+        if module not in memo:
+            if remove_duplicate:
+                memo.add(module)
+            self_non_persistent_set = set(
+                map(lambda key: prefix + ('.' if prefix else '') + key, module._non_persistent_buffers_set))
+            for name, sub_module in module._modules.items():
+                if sub_module is None:
+                    continue
+                submodule_prefix = prefix + ('.' if prefix else '') + name
+                child_non_persistent_set = self._get_non_persistent_buffers_set(sub_module, memo, submodule_prefix,
+                                                                                remove_duplicate)
+                self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set)
+        return self_non_persistent_set
 
     def _post_forward(self):
         """This function is only triggered for inference.
diff --git a/tests/test_booster/test_plugin/test_gemini_plugin.py b/tests/test_booster/test_plugin/test_gemini_plugin.py
index c7b3676fb..d606d6d89 100644
--- a/tests/test_booster/test_plugin/test_gemini_plugin.py
+++ b/tests/test_booster/test_plugin/test_gemini_plugin.py
@@ -8,10 +8,10 @@ import colossalai
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin
 from colossalai.fx import is_compatible_with_meta
+from colossalai.lazy.lazy_init import LazyInitContext
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.tensor.colo_parameter import ColoParameter
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
-from colossalai.utils.model.experimental import LazyInitContext
 from colossalai.zero import ColoInitContext
 from tests.kit.model_zoo import model_zoo
 
diff --git a/tests/test_utils/test_lazy_init/lazy_init_utils.py b/tests/test_lazy/lazy_init_utils.py
similarity index 85%
rename from tests/test_utils/test_lazy_init/lazy_init_utils.py
rename to tests/test_lazy/lazy_init_utils.py
index aa87d32a8..85bfd0e27 100644
--- a/tests/test_utils/test_lazy_init/lazy_init_utils.py
+++ b/tests/test_lazy/lazy_init_utils.py
@@ -1,12 +1,13 @@
 import random
+from copy import deepcopy
 from typing import Any, Callable, Optional, Tuple
 
 import numpy as np
 import torch
 from packaging import version
 
+from colossalai.lazy.lazy_init import LazyInitContext, LazyTensor, _MyTensor
 from colossalai.tensor.d_tensor.layout_converter import to_global
-from colossalai.utils.model.experimental import LazyInitContext, LazyTensor, _MyTensor
 from tests.kit.model_zoo.registry import ModelAttribute
 
 SUPPORT_LAZY = version.parse(torch.__version__) >= version.parse('1.12.0')
@@ -31,6 +32,9 @@ def assert_model_equal(m1: torch.nn.Module, m2: torch.nn.Module) -> None:
         assert n1 == n2
         assert torch.equal(t1, t2), f'{n1} {t1} vs {t2}'
 
+    for p1, p2 in zip(m1.parameters(), m2.parameters()):
+        assert p1.requires_grad == p2.requires_grad
+
 
 def assert_forward_equal(m1: torch.nn.Module, m2: torch.nn.Module, data_gen_fn: Callable[[], dict],
                          output_transform_fn: Callable[[Any], dict]) -> None:
@@ -65,10 +69,14 @@ def check_lazy_init(entry: TestingEntry, seed: int = 42, verbose: bool = False,
     ctx = LazyInitContext()
     with ctx:
         deferred_model = model_fn()
+        copied_deferred_model = deepcopy(deferred_model)
     deferred_model = ctx.materialize(deferred_model, verbose=verbose)
+    copied_deferred_model = ctx.materialize(copied_deferred_model, verbose=verbose)
     assert_model_equal(model, deferred_model)
+    assert_model_equal(deferred_model, copied_deferred_model)
     if check_forward:
         assert_forward_equal(model, deferred_model, data_gen_fn, output_transform_fn)
+        assert_forward_equal(deferred_model, copied_deferred_model, data_gen_fn, output_transform_fn)
     if verbose:
         print(f'{model.__class__.__name__} pass')
 
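
Condensed, the new lines in check_lazy_init exercise deepcopy-before-materialize, which is what the LazyTensor.__deepcopy__ fix above enables (a hedged sketch; the tiny model stands in for the model-zoo entries the real test iterates over):

    from copy import deepcopy

    import torch.nn as nn

    from colossalai.lazy.lazy_init import LazyInitContext

    def model_fn():
        # any constructor works; a small MLP keeps the sketch self-contained
        return nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))

    ctx = LazyInitContext()
    with ctx:
        deferred = model_fn()          # parameters are LazyTensors
        copied = deepcopy(deferred)    # relies on the __deepcopy__ fix above
    deferred = ctx.materialize(deferred)
    copied = ctx.materialize(copied)
    # both materialized models must match an eagerly built reference,
    # including requires_grad flags (what assert_model_equal now also checks)
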
diff --git a/tests/test_utils/test_lazy_init/test_distribute.py b/tests/test_lazy/test_distribute.py
similarity index 97%
rename from tests/test_utils/test_lazy_init/test_distribute.py
rename to tests/test_lazy/test_distribute.py
index fd91e7e91..d515b175a 100644
--- a/tests/test_utils/test_lazy_init/test_distribute.py
+++ b/tests/test_lazy/test_distribute.py
@@ -12,7 +12,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.utils.common import print_rank_0
 
 try:
-    from colossalai.utils.model.experimental import LazyInitContext, LazyTensor, _MyTensor
+    from colossalai.lazy.lazy_init import LazyInitContext, LazyTensor, _MyTensor
 except:
     pass
 from lazy_init_utils import SUPPORT_LAZY, assert_dist_model_equal, set_seed
diff --git a/tests/test_utils/test_lazy_init/test_models.py b/tests/test_lazy/test_models.py
similarity index 100%
rename from tests/test_utils/test_lazy_init/test_models.py
rename to tests/test_lazy/test_models.py
diff --git a/tests/test_utils/test_lazy_init_ctx.py b/tests/test_utils/test_lazy_init_ctx.py
deleted file mode 100644
index 97efb3367..000000000
--- a/tests/test_utils/test_lazy_init_ctx.py
+++ /dev/null
@@ -1,51 +0,0 @@
-import torch
-from colossalai.utils.model.lazy_init_context import LazyInitContext
-from torchvision.models import resnet34
-import random
-import numpy as np
-
-MANUAL_SEED = 0
-random.seed(MANUAL_SEED)
-np.random.seed(MANUAL_SEED)
-torch.manual_seed(MANUAL_SEED)
-
-
-def test_lazy_init_with_meta():
-    ctx = LazyInitContext(to_meta=True)
-    with ctx:
-        model = resnet34(num_classes=10)
-
-    for param in model.parameters():
-        assert param.is_meta
-    for buffer in model.buffers():
-        assert buffer.is_meta
-
-    ctx.lazy_init_parameters(model)
-
-    for name, param in model.named_parameters():
-        assert not param.is_meta, name
-
-    for buffer in model.buffers():
-        assert not buffer.is_meta
-
-
-def test_lazy_init_without_meta():
-    ctx = LazyInitContext(to_meta=False)
-    with ctx:
-        model = resnet34(num_classes=10)
-
-    for param in model.parameters():
-        assert not param.is_meta
-    for buffer in model.buffers():
-        assert not buffer.is_meta
-
-    conv1_weight_before_init = model.conv1.weight.clone()
-    ctx.lazy_init_parameters(model)
-    conv1_weight_after_init = model.conv1.weight.clone()
-
-    assert not torch.allclose(conv1_weight_after_init, conv1_weight_before_init)
-
-
-if __name__ == '__main__':
-    test_lazy_init_with_meta()
-    test_lazy_init_without_meta()