Browse Source

[hotfix] fix init context (#1543)

* fix init context

* fix lazy init ctx
pull/1546/head
ver217 2 years ago committed by GitHub
parent
commit
a203b709d5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 16
      colossalai/utils/model/lazy_init_context.py
  2. 12
      colossalai/utils/model/utils.py

16
colossalai/utils/model/lazy_init_context.py

@ -15,7 +15,7 @@ class LazyInitContext():
"""
A context to allow for lazy weight initialization of PyTorch modules. It intercepts the tensor
initialization functions for lazy initialization
Note:
This API is only experimental and subject to future changes.
@ -23,17 +23,17 @@ class LazyInitContext():
with LazyInitContext() as ctx:
model = nn.Linear(10, 10)
model.weight.zero_()
# make sure the weight is a meta tensor
assert model.weight.is_meta
# initialize weights
ctx.lazy_init_parameters(model)
# make sure the weight is not a meta tensor
# and initialized correctly
assert not model.weight.is_meta and torch.all(model.weight == 0)
Args:
to_meta (bool): optional, whether to initialize the model with meta tensors, default is False.
extra_torch_tensor_func (List[str]): extra torch tensor functions related
@ -138,14 +138,14 @@ class LazyInitContext():
cls.__orig_init__ = cls.__init__
cls.__init__ = self._wrap_module_init(cls.__init__)
substitute_init_recursively(self._torch_mod_cls, _activate_wrap_init)
substitute_init_recursively(self._torch_mod_cls, _activate_wrap_init, set())
def _unpatch_submodule_init(self):
def _recover_orig_init(cls):
cls.__init__ = cls.__orig_init__
substitute_init_recursively(self._torch_mod_cls, _recover_orig_init)
substitute_init_recursively(self._torch_mod_cls, _recover_orig_init, set())
def _patch_torch_tensor_funcs(self):
# patch tensor value-setting functions
@ -178,7 +178,7 @@ class LazyInitContext():
def lazy_init_parameters(self, model: torch.nn.Module, device='cpu'):
"""
Initialize the weights of the meta-tensor model.
Args:
model (`torch.nn.Module`): the model instantiated under the context.
device (str): the device on which weights are initialized

12
colossalai/utils/model/utils.py

@ -3,10 +3,12 @@ import functools
from typing import Optional
def substitute_init_recursively(cls, func):
def substitute_init_recursively(cls, func, visited: set):
for subcls in cls.__subclasses__():
substitute_init_recursively(subcls, func)
func(subcls)
substitute_init_recursively(subcls, func, visited)
if subcls not in visited:
func(subcls)
visited.add(subcls)
def call_to_str(base, *args, **kwargs):
@ -64,7 +66,7 @@ class InsertPostInitMethodToModuleSubClasses(object):
# Replace .__init__() for all existing subclasses of torch.nn.Module
# Excution self._post_init_method after the default init function.
substitute_init_recursively(torch.nn.modules.module.Module, _enable_class)
substitute_init_recursively(torch.nn.modules.module.Module, _enable_class, set())
# holding on to the current __init__subclass__ for exit
torch.nn.modules.module.Module._old_init_subclass = (torch.nn.modules.module.Module.__init_subclass__)
@ -87,7 +89,7 @@ class InsertPostInitMethodToModuleSubClasses(object):
cls.__init__ = cls._old_init
# Replace .__init__() for all existing subclasses of torch.nn.Module
substitute_init_recursively(torch.nn.modules.module.Module, _disable_class)
substitute_init_recursively(torch.nn.modules.module.Module, _disable_class, set())
# Replace .__init__() for future subclasses of torch.nn.Module
torch.nn.modules.module.Module.__init_subclass__ = (torch.nn.modules.module.Module._old_init_subclass)

Loading…
Cancel
Save