From 2053e138a21959dc042d621c1b056c9269d9b6bf Mon Sep 17 00:00:00 2001
From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Date: Wed, 29 Jun 2022 21:02:30 +0800
Subject: [PATCH] [context]use meta tensor to init model lazily. (#1187)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4500af6a9220ef7fe4d3c7b1daebd4c.

* [context]use meta tensor to init model lazily.

* polish

* make module with device kwargs bypass the normal init.

* change unit test to adapt updated context.
---
 colossalai/utils/model/lazy_init_context.py | 107 +++++++++++---------
 colossalai/utils/model/utils.py             |   8 +-
 tests/test_utils/test_lazy_init_ctx.py      |  29 +++---
 3 files changed, 77 insertions(+), 67 deletions(-)

diff --git a/colossalai/utils/model/lazy_init_context.py b/colossalai/utils/model/lazy_init_context.py
index 147a957ad..290ab7aac 100644
--- a/colossalai/utils/model/lazy_init_context.py
+++ b/colossalai/utils/model/lazy_init_context.py
@@ -7,6 +7,8 @@ import types
 import inspect
 import typing
 from typing import List, Callable
+from colossalai.utils.model.utils import substitute_init_recursively
+
 
 class LazyInitContext():
     """
@@ -36,29 +38,31 @@ class LazyInitContext():
         extra_torch_tensor_func (List[str]): extra torch tensor functions related
             to value setting, such as `zero_` and `triu_`. `zero_` is pre-added by default.
     """
-    
+
     tensor_set_value_func = ['zero_']
-    
+
     def __init__(self, extra_torch_tensor_func: List[str] = None):
         self._intercepted_init_func_cache = []
         self._nn_init_methods = self._get_nn_init_methods()
         self._torch_mod_cls = torch.nn.modules.module.Module
-    
+
         if extra_torch_tensor_func:
             # use tuple to remove duplicates
             self._torch_tensor_funcs = tuple(self.tensor_set_value_func + extra_torch_tensor_func)
         else:
             self._torch_tensor_funcs = self.tensor_set_value_func
-    
+
     def _cache_func(self, func):
         """
         This method wraps the ``torch.nn.init`` method so that the function call
         is cached instead of being executed.
""" + def wrapped_init_func(*args, **kwargs): self._intercepted_init_func_cache.append(dict(func=func, args=args, kwargs=kwargs)) + return wrapped_init_func - + def _get_nn_init_methods(self): """ This method looks for all available functions in the ``torch.nn.init`` @@ -66,32 +70,30 @@ class LazyInitContext(): """ nn_init_method_names = dir(torch.nn.init) nn_init_methods = [] - + # look for all methods in ``torch.nn.init`` module for name in nn_init_method_names: nn_init_methods.append((name, getattr(torch.nn.init, name))) - + def _has_tensor_in_arg(func): - hints = typing.get_type_hints(torch.nn.init.normal_) + hints = typing.get_type_hints(func) for k, v in hints.items(): if v is torch.Tensor: return True return False - + def _is_init_method(item): name, func = item - if (not isinstance(func, types.FunctionType) or - name.startswith('_') or - not name.endswith('_') or - not _has_tensor_in_arg(func)): + if (not isinstance(func, types.FunctionType) or name.startswith('_') or not name.endswith('_') + or not _has_tensor_in_arg(func)): return False else: return True - + # remove methods which are not init functions nn_init_methods = list(filter(_is_init_method, nn_init_methods)) return nn_init_methods - + def _wrap_module_init(self, func): """ This method wraps the calls to the `__init__` of ``torch.nn.Module`` and replaces @@ -99,38 +101,47 @@ class LazyInitContext(): """ has_device = 'device' in inspect.signature(func).parameters - def layer_lazy_init(*args, **kwargs): + def layer_lazy_init(module, *args, **kwargs): + self._intercepted_init_func_cache.append(dict(func=func, module=module, args=args, kwargs=kwargs)) if has_device: kwargs['device'] = 'meta' - func(*args, **kwargs) + func(module, *args, **kwargs) + if not has_device: + module.to('meta') + return layer_lazy_init - + def _get_tmp_origin_func_ref(self, name): """ Generate a function name for consistency during caching and retrieving. 
""" return f'_orig_{name}' - + def _patch_nn_init_funcs(self): # patch nn.init functions for name, func in self._nn_init_methods: setattr(torch.nn.init, name, self._cache_func(func)) - + def _unpatch_nn_init_funcs(self): # unpatch nn.init functions for name, func in self._nn_init_methods: setattr(torch.nn.init, name, func) - + def _patch_submodule_init(self): # patch classes __init__ methods - for sub_cls in self._torch_mod_cls.__subclasses__(): - sub_cls.__orig_init__ = sub_cls.__init__ - sub_cls.__init__ = self._wrap_module_init(sub_cls.__init__) - + def _activate_wrap_init(cls): + cls.__orig_init__ = cls.__init__ + cls.__init__ = self._wrap_module_init(cls.__init__) + + substitute_init_recursively(self._torch_mod_cls, _activate_wrap_init) + def _unpatch_submodule_init(self): - for sub_cls in self._torch_mod_cls.__subclasses__(): - sub_cls.__init__ = sub_cls.__orig_init__ - + + def _recover_orig_init(cls): + cls.__init__ = cls.__orig_init__ + + substitute_init_recursively(self._torch_mod_cls, _recover_orig_init) + def _patch_torch_tensor_funcs(self): # patch tensor value-setting functions for func_name in self._torch_tensor_funcs: @@ -138,24 +149,20 @@ class LazyInitContext(): origin_func = getattr(torch.Tensor, func_name) setattr(torch.Tensor, origin_func_name, origin_func) setattr(torch.Tensor, func_name, self._cache_func(origin_func)) - + def _unpatch_torch_tensor_funcs(self): for func_name in self._torch_tensor_funcs: origin_func_name = self._get_tmp_origin_func_ref(func_name) origin_func = getattr(torch.Tensor, origin_func_name) setattr(torch.Tensor, func_name, origin_func) - + def __enter__(self): - self._patch_nn_init_funcs() - self._patch_torch_tensor_funcs() self._patch_submodule_init() return self - + def __exit__(self, *args, **kwargs): self._unpatch_submodule_init() - self._unpatch_torch_tensor_funcs() - self._unpatch_nn_init_funcs() - + def lazy_init_parameters(self, model: torch.nn.Module, device='cpu', call_back: Callable = None): """ Initialize the weights of the meta-tensor model. @@ -169,13 +176,15 @@ class LazyInitContext(): param_id_to_name = dict() for name, param in model.named_parameters(): param_id_to_name[id(param)] = name - + for name, buffer in model.named_buffers(): + param_id_to_name[id(buffer)] = name + def _replace_meta_param_with_real_param(meta_param): tensor_id = id(meta_param) param_full_name = param_id_to_name[tensor_id] real_param = torch.empty_like(meta_param, dtype=meta_param.dtype, device=device) real_param = ColoParameter(real_param, requires_grad=meta_param.requires_grad) - + if '.' 
                 submodule_name, param_name = param_full_name.rsplit('.', 1)
                 submodule = model.get_submodule(submodule_name)
@@ -183,41 +192,43 @@ class LazyInitContext():
                 submodule = model
                 param_name = param_full_name
             setattr(submodule, param_name, real_param)
-    
+
             # execute call_back function on the materialized tensor
            # this is where sharding comes in
             if call_back:
                 call_back(real_param)
             return real_param
-    
-    
+
         # build modules
-        for cache in self._intercepted_init_func_cache:
+        # visit the cache list in reverse order
+        for index in range(len(self._intercepted_init_func_cache)):
+            cache = self._intercepted_init_func_cache[len(self._intercepted_init_func_cache) - index - 1]
             func = cache['func']
+            module = cache['module']
             args = list(cache['args'])
             kwargs = cache['kwargs']
-    
+
             # check args for parameter replacement
             for idx, arg in enumerate(args):
                 if torch.is_tensor(arg):
                     tensor_id = id(arg)
-    
+
                     if tensor_id not in param_id_to_name:
                         continue
                     else:
                         arg = _replace_meta_param_with_real_param(arg)
                         args[idx] = arg
-    
+
             # check kwargs for parameter replacement
             for arg_name, arg in enumerate(kwargs):
                 if torch.is_tensor(arg):
                     tensor_id = id(arg)
-    
+
                     if tensor_id not in param_id_to_name:
                         continue
                     else:
                         arg = _replace_meta_param_with_real_param(arg)
                         kwargs[arg_name] = arg
-    
+
             with torch.no_grad():
-                func(*args, **kwargs)
+                func(module, *args, **kwargs)
diff --git a/colossalai/utils/model/utils.py b/colossalai/utils/model/utils.py
index ecc0cdb5a..e1587b04f 100644
--- a/colossalai/utils/model/utils.py
+++ b/colossalai/utils/model/utils.py
@@ -3,9 +3,9 @@
 import functools
 from typing import Optional
 
 
-def _substitute_init_recursively(cls, func):
+def substitute_init_recursively(cls, func):
     for subcls in cls.__subclasses__():
-        _substitute_init_recursively(subcls, func)
+        substitute_init_recursively(subcls, func)
         func(subcls)
 
 
@@ -64,7 +64,7 @@ class InsertPostInitMethodToModuleSubClasses(object):
 
         # Replace .__init__() for all existing subclasses of torch.nn.Module
        # Execute self._post_init_method after the default init function.
-        _substitute_init_recursively(torch.nn.modules.module.Module, _enable_class)
+        substitute_init_recursively(torch.nn.modules.module.Module, _enable_class)
 
        # holding on to the current __init_subclass__ for exit
         torch.nn.modules.module.Module._old_init_subclass = (torch.nn.modules.module.Module.__init_subclass__)
@@ -87,7 +87,7 @@ class InsertPostInitMethodToModuleSubClasses(object):
             cls.__init__ = cls._old_init
 
         # Replace .__init__() for all existing subclasses of torch.nn.Module
-        _substitute_init_recursively(torch.nn.modules.module.Module, _disable_class)
+        substitute_init_recursively(torch.nn.modules.module.Module, _disable_class)
 
         # Replace .__init__() for future subclasses of torch.nn.Module
         torch.nn.modules.module.Module.__init_subclass__ = (torch.nn.modules.module.Module._old_init_subclass)
diff --git a/tests/test_utils/test_lazy_init_ctx.py b/tests/test_utils/test_lazy_init_ctx.py
index c391fd37b..4d4c0598c 100644
--- a/tests/test_utils/test_lazy_init_ctx.py
+++ b/tests/test_utils/test_lazy_init_ctx.py
@@ -1,23 +1,22 @@
 import torch
-import torch.nn as nn
 from colossalai.utils.model.lazy_init_context import LazyInitContext
+from torchvision.models import resnet34
 
 
-def test_lazy_init_ctx():
-    with LazyInitContext() as ctx:
-        model = nn.Linear(10, 10)
-        model.weight.zero_()
-
-    # make sure the weight is a meta tensor
-    assert model.weight.is_meta
-
-    # initialize weights
+def test_lazy_init():
+    ctx = LazyInitContext()
+    with ctx:
+        model = resnet34(num_classes=10)
+    for param in model.parameters():
+        assert param.is_meta
+    for buffer in model.buffers():
+        assert buffer.is_meta
     ctx.lazy_init_parameters(model)
-
-    # make sure the weight is not a meta tensor
-    # and initialized correctly
-    assert not model.weight.is_meta and torch.all(model.weight == 0)
+    for param in model.parameters():
+        assert not param.is_meta
+    for buffer in model.buffers():
+        assert not buffer.is_meta
 
 
 if __name__ == '__main__':
-    test_lazy_init_ctx()
+    test_lazy_init()
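
A minimal usage sketch of the context introduced by this patch, based on the updated test and on the
lazy_init_parameters signature above. The torchvision model and the report callback are illustrative
assumptions, not part of the patch; call_back is the hook where per-parameter processing such as
sharding could be plugged in.

    import torch
    from torchvision.models import resnet34

    from colossalai.utils.model.lazy_init_context import LazyInitContext

    # Build the model inside the context: submodule __init__ calls are intercepted and cached,
    # so parameters stay on the 'meta' device and nothing is materialized yet.
    ctx = LazyInitContext()
    with ctx:
        model = resnet34(num_classes=10)
    assert all(p.is_meta for p in model.parameters())

    # Hypothetical callback: lazy_init_parameters invokes it on each materialized parameter,
    # which is where a sharding strategy could be applied.
    def report(param):
        print(param.shape, param.device)

    # Replay the cached __init__ calls and materialize real parameters on the target device.
    ctx.lazy_init_parameters(model, device='cpu', call_back=report)
    assert not any(p.is_meta for p in model.parameters())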