diff --git a/colossalai/lazy/__init__.py b/colossalai/lazy/__init__.py
new file mode 100644
index 000000000..4387107bf
--- /dev/null
+++ b/colossalai/lazy/__init__.py
@@ -0,0 +1,6 @@
+from .lazy_init import LazyInitContext, LazyTensor
+
+__all__ = [
+    'LazyInitContext',
+    'LazyTensor',
+]
diff --git a/colossalai/utils/model/experimental.py b/colossalai/lazy/lazy_init.py
similarity index 98%
rename from colossalai/utils/model/experimental.py
rename to colossalai/lazy/lazy_init.py
index bf3e3d05b..c1fda3c53 100644
--- a/colossalai/utils/model/experimental.py
+++ b/colossalai/lazy/lazy_init.py
@@ -350,7 +350,14 @@ class LazyTensor(torch.Tensor):
                 copied.requires_grad_()
             return copied
 
-        target = LazyTensor(factory_fn, meta_data=self._meta_data)
+        if self._materialized_data is not None:
+            # self is early materialized
+            copied = self._materialized_data.detach().clone()
+            if self.requires_grad:
+                copied.requires_grad_()
+            target = LazyTensor(lambda: None, concrete_data=copied)
+        else:
+            target = LazyTensor(factory_fn, meta_data=self._meta_data)
 
         memo[id(self)] = target
         return target
diff --git a/colossalai/utils/model/lazy_init_context.py b/colossalai/utils/model/lazy_init_context.py
deleted file mode 100644
index cf05f9660..000000000
--- a/colossalai/utils/model/lazy_init_context.py
+++ /dev/null
@@ -1,242 +0,0 @@
-#!/usr/bin/env python
-# coding: utf-8
-
-import inspect
-import types
-from typing import Callable, List
-
-import torch
-import torch.nn as nn
-
-from colossalai.tensor import ColoParameter, ColoTensor
-from colossalai.utils.model.utils import substitute_init_recursively
-
-
-class LazyInitContext():
-    """
-    A context to allow for lazy weight initialization of PyTorch modules. It intercepts the tensor
-    initialization functions for lazy initialization
-
-    Note:
-        This API is only experimental and subject to future changes.
-
-    Usage:
-        with LazyInitContext() as ctx:
-            model = nn.Linear(10, 10)
-            model.weight.zero_()
-
-        # make sure the weight is a meta tensor
-        assert model.weight.is_meta
-
-        # initialize weights
-        ctx.lazy_init_parameters(model)
-
-        # make sure the weight is not a meta tensor
-        # and initialized correctly
-        assert not model.weight.is_meta and torch.all(model.weight == 0)
-
-    Args:
-        to_meta (bool): optional, whether to initialize the model with meta tensors, default is True. This
-            argument exists for now because some corner cases such as self.weight = torch.zeros(...) cannot be captured yet.
-        extra_torch_tensor_func (List[str]): extra torch tensor functions related
-            to value setting, such as `zero_` and `triu_`. `zero_` is pre-added by default.
-    """
-
-    tensor_set_value_func = ['zero_', 'fill_']
-
-    def __init__(self, to_meta: bool = True, extra_torch_tensor_func: List[str] = None):
-        # TODO: hijack the torch constructor functions as well
-        self._to_meta = to_meta
-        self._intercepted_nn_init_func_cache = {}
-        self._nn_init_methods = self._get_nn_init_methods()
-        self._torch_mod_cls = torch.nn.modules.module.Module
-
-        if extra_torch_tensor_func:
-            # use tuple to remove duplicates
-            self._torch_tensor_funcs = tuple(self.tensor_set_value_func + extra_torch_tensor_func)
-        else:
-            self._torch_tensor_funcs = self.tensor_set_value_func
-
-    @property
-    def to_meta(self):
-        return self._to_meta
-
-    def _cache_init_func(self, func):
-        """
-        This method wraps the ``torch.nn.init`` method and torch tensor value-setting functions
-        so that the function call is cached instead of being executed.
- """ - - def wrapped_init_func(tensor, *args, **kwargs): - if tensor not in self._intercepted_nn_init_func_cache: - self._intercepted_nn_init_func_cache[tensor] = [] - self._intercepted_nn_init_func_cache[tensor].append((func, args, kwargs)) - - return wrapped_init_func - - def _get_nn_init_methods(self): - """ - This method looks for all available functions in the ``torch.nn.init`` - module. - """ - nn_init_method_names = dir(torch.nn.init) - nn_init_methods = [] - - # look for all methods in ``torch.nn.init`` module - for name in nn_init_method_names: - nn_init_methods.append((name, getattr(torch.nn.init, name))) - - def _is_init_method(item): - name, func = item - - if (not isinstance(func, types.FunctionType) or name.startswith('_') or not name.endswith('_')): - return False - else: - return True - - # remove methods which are not init functions - nn_init_methods = list(filter(_is_init_method, nn_init_methods)) - return nn_init_methods - - def _wrap_module_init(self, func): - """ - This method wraps the calls to the `__init__` of ``torch.nn.Module`` and replaces - the argument device with value 'meta' so that all modules are created as meta tensors. - """ - has_device = 'device' in inspect.signature(func).parameters - - def layer_lazy_init(module, *args, **kwargs): - # if this module contains device argument - # we set it to meta to initialize as meta backend - if has_device: - kwargs['device'] = 'meta' - func(module, *args, **kwargs) - - # if device is not found, we intialize it and convert to meta - if not has_device: - module.to('meta') - - return layer_lazy_init - - def _get_tmp_origin_func_ref(self, name): - """ - Generate a function name for consistency during caching and retrieving. - """ - return f'_orig_{name}' - - def _patch_nn_init_funcs(self): - # patch nn.init functions - for name, func in self._nn_init_methods: - setattr(torch.nn.init, name, self._cache_init_func(func)) - - def _unpatch_nn_init_funcs(self): - # unpatch nn.init functions - for name, func in self._nn_init_methods: - setattr(torch.nn.init, name, func) - - def _patch_submodule_init(self): - # patch classes __init__ methods - def _activate_wrap_init(cls): - cls.__orig_init__ = cls.__init__ - cls.__init__ = self._wrap_module_init(cls.__init__) - - substitute_init_recursively(self._torch_mod_cls, _activate_wrap_init, set()) - - def _unpatch_submodule_init(self): - - def _recover_orig_init(cls): - cls.__init__ = cls.__orig_init__ - - substitute_init_recursively(self._torch_mod_cls, _recover_orig_init, set()) - - def _patch_torch_tensor_funcs(self): - # patch tensor value-setting functions - for func_name in self._torch_tensor_funcs: - origin_func_name = self._get_tmp_origin_func_ref(func_name) - origin_func = getattr(torch.Tensor, func_name) - setattr(torch.Tensor, origin_func_name, origin_func) - setattr(torch.Tensor, func_name, self._cache_init_func(origin_func)) - - def _unpatch_torch_tensor_funcs(self): - for func_name in self._torch_tensor_funcs: - origin_func_name = self._get_tmp_origin_func_ref(func_name) - origin_func = getattr(torch.Tensor, origin_func_name) - setattr(torch.Tensor, func_name, origin_func) - - def __enter__(self): - self._patch_torch_tensor_funcs() - self._patch_nn_init_funcs() - - if self._to_meta: - self._patch_submodule_init() - return self - - def __exit__(self, *args, **kwargs): - if self._to_meta: - self._unpatch_submodule_init() - self._unpatch_nn_init_funcs() - self._unpatch_torch_tensor_funcs() - - def lazy_init_parameters(self, model: torch.nn.Module, device='cpu'): - """ - 
-        Initialize the weights of the meta-tensor model.
-
-        Args:
-            model (`torch.nn.Module`): the model instantiated under the context.
-            device (str): the device on which weights are initialized
-
-        """
-
-        def _init_recursively(module: nn.Module):
-            # recursively initialize the module
-            for mod in module.children():
-                _init_recursively(mod)
-
-            # initialize and shard tensors directly attached to the current module
-            for name, param in module.named_parameters(recurse=False):
-                _init_and_shard(module, name, param)
-
-            for name, buf in module.named_buffers(recurse=False):
-                _init_and_shard(module, name, buf)
-
-        @torch.no_grad()
-        def _init_and_shard(module, name, tensor):
-            # check whether the tensor is a buffer or parameter
-            is_param = isinstance(tensor, nn.parameter.Parameter)
-
-            # get sharding spec
-            dist_spec = getattr(tensor, 'dist_spec', None)
-            pg = getattr(tensor, 'pg', None)
-            comp_spec = getattr(tensor, 'comp_spec', None)
-
-            # convert the tensor from meta to materialized one
-            if tensor.is_meta:
-                materialized_tensor = torch.empty_like(tensor, device=device)
-                # if this tensor is a meta tensor, it must have an init function
-                assert tensor in self._intercepted_nn_init_func_cache
-            else:
-                materialized_tensor = tensor
-
-            # apply init function
-            if tensor in self._intercepted_nn_init_func_cache:
-                init_func, args, kwargs = self._intercepted_nn_init_func_cache[tensor][-1]
-                init_func(materialized_tensor, *args, **kwargs)
-
-            # convert it to ColoTensor or ColoParameter
-            if is_param:
-                tensor = ColoParameter.from_torch_tensor(materialized_tensor, requires_grad=tensor.requires_grad)
-            else:
-                tensor = ColoTensor.from_torch_tensor(materialized_tensor)
-
-            # override the original tensor
-            with torch.no_grad():
-                setattr(module, name, tensor)
-
-            # apply sharding
-            if dist_spec:
-                tensor.process_group = pg
-                tensor.set_tensor_spec(dist_spec, comp_spec)
-
-        _init_recursively(model)
-
-        return model
diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py
index 878c25be7..fd49362d6 100644
--- a/colossalai/zero/gemini/gemini_ddp.py
+++ b/colossalai/zero/gemini/gemini_ddp.py
@@ -2,13 +2,14 @@ import itertools
 from collections import OrderedDict
 from contextlib import nullcontext
 from functools import partial
-from typing import Dict, Iterator, List, Optional, Union, Tuple, Set
+from typing import Dict, Iterator, List, Optional, Set, Tuple, Union
 
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 
 from colossalai.checkpoint_io.utils import calculate_tensor_size
+from colossalai.lazy import LazyTensor
 from colossalai.logging import get_dist_logger
 from colossalai.nn.parallel.data_parallel import ColoDDP, _cast_float, free_storage
 from colossalai.tensor import ProcessGroup as ColoProcessGroup
@@ -16,7 +17,6 @@ from colossalai.tensor import ReplicaSpec
 from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.utils import get_current_device, is_ddp_ignored
-from colossalai.utils.model.experimental import LazyTensor
 
 from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager
 from .gemini_hook import GeminiZeROHook
@@ -96,34 +96,38 @@ class ZeroDDP(ColoDDP):
                 param_name = m_name + '.' + p_name if m_name else p_name
                 self.name2param[param_name] = p_var
         super().__init__(module, process_group=ColoProcessGroup())
-        self._non_persistent_buffers_set=self._get_non_persistent_buffers_set(module)
+        self._non_persistent_buffers_set = self._get_non_persistent_buffers_set(module)
         self._cast_buffers()
 
-    def _get_non_persistent_buffers_set(self, module, memo: Optional[Set[nn.Module]] = None, prefix: str = '', remove_duplicate: bool = True):
-
-        r"""
-        Args:
-            memo: a memo to store the set of modules already added to the result
-            prefix: a prefix that will be added to the name of the module
-            remove_duplicate: whether to remove the duplicated module instances in the result
-            or not
-        """
-
-        if memo is None:
-            memo = set()
-        self_non_persistent_set = set()
-        if module not in memo:
-            if remove_duplicate:
-                memo.add(module)
-            self_non_persistent_set = set(map(lambda key: prefix + ('.' if prefix else '') + key, module._non_persistent_buffers_set))
-            for name, sub_module in module._modules.items():
-                if sub_module is None:
-                    continue
-                submodule_prefix = prefix + ('.' if prefix else '') + name
-                child_non_persistent_set = self._get_non_persistent_buffers_set(sub_module, memo, submodule_prefix, remove_duplicate)
-                self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set)
-        return self_non_persistent_set
-
+    def _get_non_persistent_buffers_set(self,
+                                        module,
+                                        memo: Optional[Set[nn.Module]] = None,
+                                        prefix: str = '',
+                                        remove_duplicate: bool = True):
+        r"""
+        Args:
+            memo: a memo to store the set of modules already added to the result
+            prefix: a prefix that will be added to the name of the module
+            remove_duplicate: whether to remove the duplicated module instances in the result
+            or not
+        """
+
+        if memo is None:
+            memo = set()
+        self_non_persistent_set = set()
+        if module not in memo:
+            if remove_duplicate:
+                memo.add(module)
+            self_non_persistent_set = set(
+                map(lambda key: prefix + ('.' if prefix else '') + key, module._non_persistent_buffers_set))
+            for name, sub_module in module._modules.items():
+                if sub_module is None:
+                    continue
+                submodule_prefix = prefix + ('.' if prefix else '') + name
+                child_non_persistent_set = self._get_non_persistent_buffers_set(sub_module, memo, submodule_prefix,
+                                                                                remove_duplicate)
+                self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set)
+        return self_non_persistent_set
 
     def _post_forward(self):
         """This function is only triggered for inference.
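The new `_get_non_persistent_buffers_set` helper in the hunk above does a depth-first walk over the module tree and returns the fully-qualified names of every buffer registered with `persistent=False`, using the `memo` set to avoid revisiting shared submodules. The sketch below is not part of the patch; it re-implements the same traversal as a standalone function with illustrative names, so the recursion and prefix handling are easier to follow (like the patched method, it relies on the private `nn.Module._non_persistent_buffers_set` attribute).

# Standalone sketch of the traversal implemented by ZeroDDP._get_non_persistent_buffers_set;
# the function and variable names here are illustrative only.
from typing import Optional, Set

import torch
import torch.nn as nn


def collect_non_persistent_buffer_names(module: nn.Module,
                                        memo: Optional[Set[nn.Module]] = None,
                                        prefix: str = '') -> Set[str]:
    # depth-first walk over the module tree, skipping modules already visited,
    # returning fully-qualified names of non-persistent buffers
    if memo is None:
        memo = set()
    names: Set[str] = set()
    if module not in memo:
        memo.add(module)
        names = {prefix + ('.' if prefix else '') + key for key in module._non_persistent_buffers_set}
        for name, sub_module in module._modules.items():
            if sub_module is None:
                continue
            sub_prefix = prefix + ('.' if prefix else '') + name
            names |= collect_non_persistent_buffer_names(sub_module, memo, sub_prefix)
    return names


if __name__ == '__main__':
    layer = nn.Linear(4, 4)
    # register a non-persistent buffer so the result is non-empty
    layer.register_buffer('scratch', torch.zeros(4), persistent=False)
    model = nn.Sequential(layer, nn.ReLU())
    print(collect_non_persistent_buffer_names(model))    # {'0.scratch'}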
diff --git a/tests/test_booster/test_plugin/test_gemini_plugin.py b/tests/test_booster/test_plugin/test_gemini_plugin.py
index c7b3676fb..d606d6d89 100644
--- a/tests/test_booster/test_plugin/test_gemini_plugin.py
+++ b/tests/test_booster/test_plugin/test_gemini_plugin.py
@@ -8,10 +8,10 @@ import colossalai
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin
 from colossalai.fx import is_compatible_with_meta
+from colossalai.lazy.lazy_init import LazyInitContext
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.tensor.colo_parameter import ColoParameter
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
-from colossalai.utils.model.experimental import LazyInitContext
 from colossalai.zero import ColoInitContext
 
 from tests.kit.model_zoo import model_zoo
diff --git a/tests/test_utils/test_lazy_init/lazy_init_utils.py b/tests/test_lazy/lazy_init_utils.py
similarity index 85%
rename from tests/test_utils/test_lazy_init/lazy_init_utils.py
rename to tests/test_lazy/lazy_init_utils.py
index aa87d32a8..85bfd0e27 100644
--- a/tests/test_utils/test_lazy_init/lazy_init_utils.py
+++ b/tests/test_lazy/lazy_init_utils.py
@@ -1,12 +1,13 @@
 import random
+from copy import deepcopy
 from typing import Any, Callable, Optional, Tuple
 
 import numpy as np
 import torch
 from packaging import version
 
+from colossalai.lazy.lazy_init import LazyInitContext, LazyTensor, _MyTensor
 from colossalai.tensor.d_tensor.layout_converter import to_global
-from colossalai.utils.model.experimental import LazyInitContext, LazyTensor, _MyTensor
 from tests.kit.model_zoo.registry import ModelAttribute
 
 SUPPORT_LAZY = version.parse(torch.__version__) >= version.parse('1.12.0')
@@ -31,6 +32,9 @@ def assert_model_equal(m1: torch.nn.Module, m2: torch.nn.Module) -> None:
         assert n1 == n2
         assert torch.equal(t1, t2), f'{n1} {t1} vs {t2}'
 
+    for p1, p2 in zip(m1.parameters(), m2.parameters()):
+        assert p1.requires_grad == p2.requires_grad
+
 
 def assert_forward_equal(m1: torch.nn.Module, m2: torch.nn.Module, data_gen_fn: Callable[[], dict],
                          output_transform_fn: Callable[[Any], dict]) -> None:
@@ -65,10 +69,14 @@ def check_lazy_init(entry: TestingEntry, seed: int = 42, verbose: bool = False,
     ctx = LazyInitContext()
     with ctx:
         deferred_model = model_fn()
+        copied_deferred_model = deepcopy(deferred_model)
     deferred_model = ctx.materialize(deferred_model, verbose=verbose)
+    copied_deferred_model = ctx.materialize(copied_deferred_model, verbose=verbose)
     assert_model_equal(model, deferred_model)
+    assert_model_equal(deferred_model, copied_deferred_model)
     if check_forward:
         assert_forward_equal(model, deferred_model, data_gen_fn, output_transform_fn)
+        assert_forward_equal(deferred_model, copied_deferred_model, data_gen_fn, output_transform_fn)
     if verbose:
         print(f'{model.__class__.__name__} pass')
 
diff --git a/tests/test_utils/test_lazy_init/test_distribute.py b/tests/test_lazy/test_distribute.py
similarity index 97%
rename from tests/test_utils/test_lazy_init/test_distribute.py
rename to tests/test_lazy/test_distribute.py
index fd91e7e91..d515b175a 100644
--- a/tests/test_utils/test_lazy_init/test_distribute.py
+++ b/tests/test_lazy/test_distribute.py
@@ -12,7 +12,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.utils.common import print_rank_0
 
 try:
-    from colossalai.utils.model.experimental import LazyInitContext, LazyTensor, _MyTensor
+    from colossalai.lazy.lazy_init import LazyInitContext, LazyTensor, _MyTensor
 except:
     pass
 from lazy_init_utils import SUPPORT_LAZY, assert_dist_model_equal, set_seed
diff --git a/tests/test_utils/test_lazy_init/test_models.py b/tests/test_lazy/test_models.py
similarity index 100%
rename from tests/test_utils/test_lazy_init/test_models.py
rename to tests/test_lazy/test_models.py
diff --git a/tests/test_utils/test_lazy_init_ctx.py b/tests/test_utils/test_lazy_init_ctx.py
deleted file mode 100644
index 97efb3367..000000000
--- a/tests/test_utils/test_lazy_init_ctx.py
+++ /dev/null
@@ -1,51 +0,0 @@
-import torch
-from colossalai.utils.model.lazy_init_context import LazyInitContext
-from torchvision.models import resnet34
-import random
-import numpy as np
-
-MANUAL_SEED = 0
-random.seed(MANUAL_SEED)
-np.random.seed(MANUAL_SEED)
-torch.manual_seed(MANUAL_SEED)
-
-
-def test_lazy_init_with_meta():
-    ctx = LazyInitContext(to_meta=True)
-    with ctx:
-        model = resnet34(num_classes=10)
-
-    for param in model.parameters():
-        assert param.is_meta
-    for buffer in model.buffers():
-        assert buffer.is_meta
-
-    ctx.lazy_init_parameters(model)
-
-    for name, param in model.named_parameters():
-        assert not param.is_meta, name
-
-    for buffer in model.buffers():
-        assert not buffer.is_meta
-
-
-def test_lazy_init_without_meta():
-    ctx = LazyInitContext(to_meta=False)
-    with ctx:
-        model = resnet34(num_classes=10)
-
-    for param in model.parameters():
-        assert not param.is_meta
-    for buffer in model.buffers():
-        assert not buffer.is_meta
-
-    conv1_weight_before_init = model.conv1.weight.clone()
-    ctx.lazy_init_parameters(model)
-    conv1_weight_after_init = model.conv1.weight.clone()
-
-    assert not torch.allclose(conv1_weight_after_init, conv1_weight_before_init)
-
-
-if __name__ == '__main__':
-    test_lazy_init_with_meta()
-    test_lazy_init_without_meta()
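Taken together, the patch moves lazy initialization out of `colossalai.utils.model.experimental` into the new `colossalai.lazy` package, drops the older `LazyInitContext` prototype and its test, and fixes `LazyTensor.__deepcopy__` so that an early-materialized tensor is copied from its concrete data. Below is a minimal usage sketch of the relocated API, assuming only what the updated tests exercise (torch >= 1.12 per `SUPPORT_LAZY`, and the `LazyInitContext.materialize` method); the model is arbitrary.

# Usage sketch based on tests/test_lazy/lazy_init_utils.py in this patch
# (assumption: torch >= 1.12, as required by SUPPORT_LAZY).
from copy import deepcopy

import torch.nn as nn

from colossalai.lazy import LazyInitContext

ctx = LazyInitContext()
with ctx:
    # parameters are created as LazyTensor placeholders; no real storage yet
    model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))
    # deepcopy of a not-yet-materialized model is what the __deepcopy__ change
    # above is meant to support
    model_copy = deepcopy(model)

# materialization allocates and initializes the real weights
model = ctx.materialize(model)
model_copy = ctx.materialize(model_copy)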