mirror of https://github.com/hpcaitech/ColossalAI
[context]use meta tensor to init model lazily. (#1187)
* [CLI] add CLI launcher
* Revert "[CLI] add CLI launcher"
This reverts commit df7e6506d4.
* [context]use meta tensor to init model lazily.
* polish
* make module with device kwargs bypass the normal init.
* change unit test to adapt updated context.
Branch: pull/1191/head
Parent: 2c8c05675d
Commit: 2053e138a2
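For context, the change builds on PyTorch's meta device: tensors created there carry shape and dtype but allocate no storage, so a model can be "built" without initializing real weights. A minimal illustration of that behaviour (assuming PyTorch >= 1.10; not part of the commit):

```python
import torch
import torch.nn as nn

# A meta tensor records shape/dtype only; no memory is allocated.
t = torch.empty(4, 4, device='meta')
print(t.is_meta, t.shape)          # True torch.Size([4, 4])

# Modules whose __init__ accepts a `device` kwarg can be built directly
# on the meta device, which is what the patched inits below exploit
# for the "has_device" case.
layer = nn.Linear(8, 8, device='meta')
print(layer.weight.is_meta)        # True
```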
Changed file 1 — colossalai/utils/model/lazy_init_context.py (path inferred from the `LazyInitContext` import in the updated test below):

```diff
@@ -7,6 +7,8 @@ import types
 import inspect
 import typing
 from typing import List, Callable
+from colossalai.utils.model.utils import substitute_init_recursively


 class LazyInitContext():
     """
```
```diff
@@ -55,8 +57,10 @@ class LazyInitContext():
         This method wraps the ``torch.nn.init`` method so that the function call
         is cached instead of being executed.
         """

         def wrapped_init_func(*args, **kwargs):
             self._intercepted_init_func_cache.append(dict(func=func, args=args, kwargs=kwargs))

         return wrapped_init_func

     def _get_nn_init_methods(self):
```
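The hunk above keeps `_wrap_nn_init_func`, which records a `torch.nn.init` call instead of running it. A standalone sketch of that record-and-replay pattern (the names `defer` and `cached_calls` are illustrative, not from the source):

```python
import torch

cached_calls = []

def defer(func):
    # Return a wrapper that records the call instead of executing it.
    def wrapped(*args, **kwargs):
        cached_calls.append((func, args, kwargs))
    return wrapped

deferred_normal_ = defer(torch.nn.init.normal_)
w = torch.empty(2, 3)
deferred_normal_(w, mean=0.0, std=0.02)   # nothing is initialized yet
assert len(cached_calls) == 1

func, args, kwargs = cached_calls[0]
func(*args, **kwargs)                     # replayed later: w is now filled
```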
```diff
@@ -72,7 +76,7 @@ class LazyInitContext():
             nn_init_methods.append((name, getattr(torch.nn.init, name)))

         def _has_tensor_in_arg(func):
-            hints = typing.get_type_hints(torch.nn.init.normal_)
+            hints = typing.get_type_hints(func)
             for k, v in hints.items():
                 if v is torch.Tensor:
                     return True
```
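The fix above makes the filter inspect the function actually being tested rather than always reading the annotations of `torch.nn.init.normal_`. A quick check of what `typing.get_type_hints` returns for an init function:

```python
import typing
import torch

# The init functions are annotated, so a tensor argument shows up as
# torch.Tensor in the resolved type hints.
hints = typing.get_type_hints(torch.nn.init.xavier_uniform_)
print(any(v is torch.Tensor for v in hints.values()))   # True
```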
```diff
@@ -80,10 +84,8 @@ class LazyInitContext():

         def _is_init_method(item):
             name, func = item
-            if (not isinstance(func, types.FunctionType) or
-                    name.startswith('_') or
-                    not name.endswith('_') or
-                    not _has_tensor_in_arg(func)):
+            if (not isinstance(func, types.FunctionType) or name.startswith('_') or not name.endswith('_')
+                    or not _has_tensor_in_arg(func)):
                 return False
             else:
                 return True

```
```diff
@@ -99,10 +101,14 @@ class LazyInitContext():
         """
         has_device = 'device' in inspect.signature(func).parameters

-        def layer_lazy_init(*args, **kwargs):
+        def layer_lazy_init(module, *args, **kwargs):
+            self._intercepted_init_func_cache.append(dict(func=func, module=module, args=args, kwargs=kwargs))
             if has_device:
                 kwargs['device'] = 'meta'
-            func(*args, **kwargs)
+            func(module, *args, **kwargs)
+            if not has_device:
+                module.to('meta')

         return layer_lazy_init

     def _get_tmp_origin_func_ref(self, name):
```
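`layer_lazy_init` above now records the intercepted constructor call (including the module instance) and forces construction onto the meta device, falling back to `module.to('meta')` when the class has no `device` kwarg. A single-class sketch of the same idea (patching only `nn.Linear` here, with illustrative names; the real context patches every `nn.Module` subclass):

```python
import torch.nn as nn

init_cache = []                      # illustrative stand-in for the cache
orig_init = nn.Linear.__init__

def lazy_linear_init(module, *args, **kwargs):
    # Record the call so it can be replayed with real tensors later.
    init_cache.append(dict(module=module, args=args, kwargs=kwargs))
    kwargs['device'] = 'meta'        # nn.Linear accepts a device kwarg
    orig_init(module, *args, **kwargs)

nn.Linear.__init__ = lazy_linear_init
layer = nn.Linear(4, 4)
nn.Linear.__init__ = orig_init       # restore the original constructor

assert layer.weight.is_meta          # parameters exist, but hold no data
assert len(init_cache) == 1
```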
```diff
@@ -123,13 +129,18 @@ class LazyInitContext():

     def _patch_submodule_init(self):
         # patch classes __init__ methods
-        for sub_cls in self._torch_mod_cls.__subclasses__():
-            sub_cls.__orig_init__ = sub_cls.__init__
-            sub_cls.__init__ = self._wrap_module_init(sub_cls.__init__)
+        def _activate_wrap_init(cls):
+            cls.__orig_init__ = cls.__init__
+            cls.__init__ = self._wrap_module_init(cls.__init__)
+
+        substitute_init_recursively(self._torch_mod_cls, _activate_wrap_init)

     def _unpatch_submodule_init(self):
-        for sub_cls in self._torch_mod_cls.__subclasses__():
-            sub_cls.__init__ = sub_cls.__orig_init__
+
+        def _recover_orig_init(cls):
+            cls.__init__ = cls.__orig_init__
+
+        substitute_init_recursively(self._torch_mod_cls, _recover_orig_init)

     def _patch_torch_tensor_funcs(self):
         # patch tensor value-setting functions
```
```diff
@@ -146,15 +157,11 @@ class LazyInitContext():
             setattr(torch.Tensor, func_name, origin_func)

     def __enter__(self):
-        self._patch_nn_init_funcs()
-        self._patch_torch_tensor_funcs()
         self._patch_submodule_init()
         return self

     def __exit__(self, *args, **kwargs):
         self._unpatch_submodule_init()
-        self._unpatch_torch_tensor_funcs()
-        self._unpatch_nn_init_funcs()

     def lazy_init_parameters(self, model: torch.nn.Module, device='cpu', call_back: Callable = None):
         """
```
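`__enter__`/`__exit__` above reduce to patching the subclass constructors on entry and restoring the saved originals on exit. A compact, illustrative sketch of that enter/exit symmetry for a single class:

```python
import torch.nn as nn

class PatchLinearToMeta:
    # Patch on entry, restore on exit, so nothing leaks outside the block.
    def __enter__(self):
        self._orig = nn.Linear.__init__
        nn.Linear.__init__ = lambda m, *a, **kw: self._orig(m, *a, device='meta', **kw)
        return self

    def __exit__(self, *exc):
        nn.Linear.__init__ = self._orig

with PatchLinearToMeta():
    lazy = nn.Linear(3, 3)

assert lazy.weight.is_meta
assert not nn.Linear(3, 3).weight.is_meta    # original behaviour restored
```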
```diff
@@ -169,6 +176,8 @@ class LazyInitContext():
         param_id_to_name = dict()
         for name, param in model.named_parameters():
             param_id_to_name[id(param)] = name
+        for name, buffer in model.named_buffers():
+            param_id_to_name[id(buffer)] = name

         def _replace_meta_param_with_real_param(meta_param):
             tensor_id = id(meta_param)
```
```diff
@@ -190,10 +199,12 @@ class LazyInitContext():
                 call_back(real_param)
             return real_param

         # build modules
-        for cache in self._intercepted_init_func_cache:
+        # visit the cache list in reverse order
+        for index in range(len(self._intercepted_init_func_cache)):
+            cache = self._intercepted_init_func_cache[len(self._intercepted_init_func_cache) - index - 1]
             func = cache['func']
+            module = cache['module']
             args = list(cache['args'])
             kwargs = cache['kwargs']

```
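The replay loop above walks the cache in reverse. Interception appends entries child-first, since a parent's constructor cannot finish before its submodules exist, so the reverse walk starts from the outermost module. A small trace of that append order (the tracking code is illustrative, not from the commit):

```python
import torch.nn as nn

order = []
orig_module_init = nn.Module.__init__

def tracking_init(module, *args, **kwargs):
    orig_module_init(module, *args, **kwargs)
    order.append(type(module).__name__)

# Record the order in which constructors run when building a container.
nn.Module.__init__ = tracking_init
model = nn.Sequential(nn.Linear(2, 2), nn.ReLU())
nn.Module.__init__ = orig_module_init

print(order)   # ['Linear', 'ReLU', 'Sequential'] -- children come first
```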
```diff
@@ -220,4 +231,4 @@ class LazyInitContext():
                 kwargs[arg_name] = arg

             with torch.no_grad():
-                func(*args, **kwargs)
+                func(module, *args, **kwargs)
```
Changed file 2 — colossalai/utils/model/utils.py (path inferred from the import added in the first file):

```diff
@@ -3,9 +3,9 @@ import functools
 from typing import Optional


-def _substitute_init_recursively(cls, func):
+def substitute_init_recursively(cls, func):
     for subcls in cls.__subclasses__():
-        _substitute_init_recursively(subcls, func)
+        substitute_init_recursively(subcls, func)
         func(subcls)


```
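`substitute_init_recursively` (the private helper renamed above and now shared with `LazyInitContext`) is recursive because `__subclasses__()` only lists direct children; without recursion, deeper subclasses would never be patched. For example:

```python
import torch.nn as nn

# A user-defined layer derived from nn.Linear is not a *direct*
# subclass of nn.Module, so a non-recursive walk would miss it.
class MyBlock(nn.Linear):
    pass

print(MyBlock in nn.Module.__subclasses__())   # False
print(MyBlock in nn.Linear.__subclasses__())   # True
```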
```diff
@@ -64,7 +64,7 @@ class InsertPostInitMethodToModuleSubClasses(object):

         # Replace .__init__() for all existing subclasses of torch.nn.Module
         # Excution self._post_init_method after the default init function.
-        _substitute_init_recursively(torch.nn.modules.module.Module, _enable_class)
+        substitute_init_recursively(torch.nn.modules.module.Module, _enable_class)

         # holding on to the current __init__subclass__ for exit
         torch.nn.modules.module.Module._old_init_subclass = (torch.nn.modules.module.Module.__init_subclass__)
```
```diff
@@ -87,7 +87,7 @@ class InsertPostInitMethodToModuleSubClasses(object):
             cls.__init__ = cls._old_init

         # Replace .__init__() for all existing subclasses of torch.nn.Module
-        _substitute_init_recursively(torch.nn.modules.module.Module, _disable_class)
+        substitute_init_recursively(torch.nn.modules.module.Module, _disable_class)

         # Replace .__init__() for future subclasses of torch.nn.Module
         torch.nn.modules.module.Module.__init_subclass__ = (torch.nn.modules.module.Module._old_init_subclass)
```
Changed file 3 — unit test for LazyInitContext:

```diff
@@ -1,23 +1,22 @@
-import torch
-import torch.nn as nn
 from colossalai.utils.model.lazy_init_context import LazyInitContext
+from torchvision.models import resnet34


-def test_lazy_init_ctx():
-
-    with LazyInitContext() as ctx:
-        model = nn.Linear(10, 10)
-        model.weight.zero_()
-
-    # make sure the weight is a meta tensor
-    assert model.weight.is_meta
-
-    # initialize weights
+def test_lazy_init():
+    ctx = LazyInitContext()
+    with ctx:
+        model = resnet34(num_classes=10)
+    for param in model.parameters():
+        assert param.is_meta
+    for buffer in model.buffers():
+        assert buffer.is_meta
     ctx.lazy_init_parameters(model)
-
-    # make sure the weight is not a meta tensor
-    # and initialized correctly
-    assert not model.weight.is_meta and torch.all(model.weight == 0)
+    for param in model.parameters():
+        assert not param.is_meta
+    for buffer in model.buffers():
+        assert not buffer.is_meta


 if __name__ == '__main__':
-    test_lazy_init_ctx()
+    test_lazy_init()
```
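Beyond the test above, the `lazy_init_parameters` signature shown earlier also takes `device` and `call_back` arguments, and the replay path invokes `call_back(real_param)` on each materialized tensor. A hypothetical usage sketch built only from that signature, not taken from the repository:

```python
import torch
from torchvision.models import resnet34
from colossalai.utils.model.lazy_init_context import LazyInitContext

def report(real_param):
    # Called with each materialized parameter/buffer during replay.
    print(tuple(real_param.shape), real_param.device)

ctx = LazyInitContext()
with ctx:
    model = resnet34(num_classes=10)   # built entirely on the meta device

ctx.lazy_init_parameters(model, device='cpu', call_back=report)
assert not any(p.is_meta for p in model.parameters())
```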