diff --git a/colossalai/utils/model/lazy_init_context.py b/colossalai/utils/model/lazy_init_context.py
index 057cbd8f5..ed94429d4 100644
--- a/colossalai/utils/model/lazy_init_context.py
+++ b/colossalai/utils/model/lazy_init_context.py
@@ -15,7 +15,7 @@ class LazyInitContext():
     """
     A context to allow for lazy weight initialization of PyTorch modules. It intercepts the tensor
     initialization functions for lazy initialization
-    
+
     Note:
         This API is only experimental and subject to future changes.
 
@@ -23,17 +23,17 @@ class LazyInitContext():
         with LazyInitContext() as ctx:
             model = nn.Linear(10, 10)
             model.weight.zero_()
-        
+
         # make sure the weight is a meta tensor
         assert model.weight.is_meta
-        
+
         # initialize weights
         ctx.lazy_init_parameters(model)
-        
+
         # make sure the weight is not a meta tensor
         # and initialized correctly
         assert not model.weight.is_meta and torch.all(model.weight == 0)
-    
+
     Args:
         to_meta (bool): optional, whether to initialize the model with meta tensors, default is False.
         extra_torch_tensor_func (List[str]): extra torch tensor functions related
@@ -138,14 +138,14 @@ class LazyInitContext():
             cls.__orig_init__ = cls.__init__
             cls.__init__ = self._wrap_module_init(cls.__init__)
 
-        substitute_init_recursively(self._torch_mod_cls, _activate_wrap_init)
+        substitute_init_recursively(self._torch_mod_cls, _activate_wrap_init, set())
 
     def _unpatch_submodule_init(self):
 
         def _recover_orig_init(cls):
             cls.__init__ = cls.__orig_init__
 
-        substitute_init_recursively(self._torch_mod_cls, _recover_orig_init)
+        substitute_init_recursively(self._torch_mod_cls, _recover_orig_init, set())
 
     def _patch_torch_tensor_funcs(self):
         # patch tensor value-setting functions
@@ -178,7 +178,7 @@ class LazyInitContext():
     def lazy_init_parameters(self, model: torch.nn.Module, device='cpu'):
        """
        Initialize the weights of the meta-tensor model.
-        
+
        Args:
            model (`torch.nn.Module`): the model instantiated under the context.
            device (str): the device on which weights are initialized
diff --git a/colossalai/utils/model/utils.py b/colossalai/utils/model/utils.py
index e1587b04f..75bb18df6 100644
--- a/colossalai/utils/model/utils.py
+++ b/colossalai/utils/model/utils.py
@@ -3,10 +3,12 @@ import functools
 from typing import Optional
 
 
-def substitute_init_recursively(cls, func):
+def substitute_init_recursively(cls, func, visited: set):
     for subcls in cls.__subclasses__():
-        substitute_init_recursively(subcls, func)
-        func(subcls)
+        substitute_init_recursively(subcls, func, visited)
+        if subcls not in visited:
+            func(subcls)
+            visited.add(subcls)
 
 
 def call_to_str(base, *args, **kwargs):
@@ -64,7 +66,7 @@ class InsertPostInitMethodToModuleSubClasses(object):
 
         # Replace .__init__() for all existing subclasses of torch.nn.Module
        # Excution self._post_init_method after the default init function.
-        substitute_init_recursively(torch.nn.modules.module.Module, _enable_class)
+        substitute_init_recursively(torch.nn.modules.module.Module, _enable_class, set())
 
        # holding on to the current __init__subclass__ for exit
        torch.nn.modules.module.Module._old_init_subclass = (torch.nn.modules.module.Module.__init_subclass__)
@@ -87,7 +89,7 @@ class InsertPostInitMethodToModuleSubClasses(object):
             cls.__init__ = cls._old_init
 
         # Replace .__init__() for all existing subclasses of torch.nn.Module
-        substitute_init_recursively(torch.nn.modules.module.Module, _disable_class)
+        substitute_init_recursively(torch.nn.modules.module.Module, _disable_class, set())
 
        # Replace .__init__() for future subclasses of torch.nn.Module
        torch.nn.modules.module.Module.__init_subclass__ = (torch.nn.modules.module.Module._old_init_subclass)
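
Note on the change (not part of the patch): the new `visited` set deduplicates classes that the `__subclasses__()` recursion reaches through more than one parent. Without it, a patcher such as `_activate_wrap_init` can run twice on the same class, saving the already-wrapped `__init__` as `__orig_init__`, after which `_unpatch_submodule_init` can no longer restore the true original. Below is a minimal, self-contained sketch of that traversal; the `Base`/`A`/`B`/`C` classes and the `patched` list are hypothetical stand-ins, not ColossalAI code.

# Illustrative sketch of the multiple-inheritance case the `visited` set
# guards against; class names here are hypothetical.

class Base:
    pass

class A(Base):
    pass

class B(Base):
    pass

class C(A, B):
    # C appears in both A.__subclasses__() and B.__subclasses__()
    pass

def substitute_init_recursively(cls, func, visited: set):
    # depth-first walk over the subclass graph; `visited` ensures each
    # class is patched exactly once even when reachable via several parents
    for subcls in cls.__subclasses__():
        substitute_init_recursively(subcls, func, visited)
        if subcls not in visited:
            func(subcls)
            visited.add(subcls)

patched = []
substitute_init_recursively(Base, patched.append, set())
assert patched.count(C) == 1  # without `visited`, C would be patched twice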