diff --git a/colossalai/utils/model/lazy_init_context.py b/colossalai/utils/model/lazy_init_context.py
index 290ab7aac..a72c59fee 100644
--- a/colossalai/utils/model/lazy_init_context.py
+++ b/colossalai/utils/model/lazy_init_context.py
@@ -8,6 +8,7 @@ import inspect
 import typing
 from typing import List, Callable
 from colossalai.utils.model.utils import substitute_init_recursively
+import copy
 
 
 class LazyInitContext():
@@ -102,7 +103,8 @@ class LazyInitContext():
         has_device = 'device' in inspect.signature(func).parameters
 
         def layer_lazy_init(module, *args, **kwargs):
-            self._intercepted_init_func_cache.append(dict(func=func, module=module, args=args, kwargs=kwargs))
+            self._intercepted_init_func_cache.append(
+                dict(func=func, module=module, args=args, kwargs=copy.deepcopy(kwargs)))
             if has_device:
                 kwargs['device'] = 'meta'
             func(module, *args, **kwargs)
@@ -162,6 +164,12 @@ class LazyInitContext():
 
     def __exit__(self, *args, **kwargs):
         self._unpatch_submodule_init()
+        # build module_rebuild_dict in reverse order to make sure we get the correct init func for inherited classes
+        self.module_rebuild_dict = {}
+        self._intercepted_init_func_cache.reverse()
+        for cache in self._intercepted_init_func_cache:
+            self.module_rebuild_dict[cache['module']] = (cache['func'], cache['args'], cache['kwargs'])
+        self._intercepted_init_func_cache.reverse()
 
     def lazy_init_parameters(self, model: torch.nn.Module, device='cpu', call_back: Callable = None):
         """
@@ -179,7 +187,34 @@ class LazyInitContext():
         for name, buffer in model.named_buffers():
             param_id_to_name[id(buffer)] = name
 
+        assert model in self.module_rebuild_dict, 'Only modules whose initialization was intercepted by this context can be rebuilt.'
+
+        def _process_arg(arg):
+            """
+            Process args recursively. If arg is a torch.nn.Module instance found in module_rebuild_dict,
+            rebuild it with real parameters. If arg is a tuple or list, process each of its elements
+            with this function again.
+            """
+            if torch.is_tensor(arg):
+                tensor_id = id(arg)
+                if tensor_id in param_id_to_name:
+                    arg = _replace_meta_param_with_real_param(arg)
+
+            elif isinstance(arg, torch.nn.Module):
+                if arg in self.module_rebuild_dict:
+                    arg = self.lazy_init_parameters(model=arg, device=device, call_back=call_back)
+
+            elif isinstance(arg, (tuple, list)):
+                rst_list = []
+                for element in arg:
+                    processed_element = _process_arg(element)
+                    rst_list.append(processed_element)
+                arg = rst_list
+            return arg
+
         def _replace_meta_param_with_real_param(meta_param):
+            if not meta_param.is_meta:
+                return meta_param
             tensor_id = id(meta_param)
             param_full_name = param_id_to_name[tensor_id]
             real_param = torch.empty_like(meta_param, dtype=meta_param.dtype, device=device)
@@ -199,36 +234,24 @@ class LazyInitContext():
                 call_back(real_param)
             return real_param
 
-        # build modules
-        # visit the cache list in reverse order
-        for index in range(len(self._intercepted_init_func_cache)):
-            cache = self._intercepted_init_func_cache[len(self._intercepted_init_func_cache) - index - 1]
-            func = cache['func']
-            module = cache['module']
-            args = list(cache['args'])
-            kwargs = cache['kwargs']
-
-            # check args for parameter replacement
-            for idx, arg in enumerate(args):
-                if torch.is_tensor(arg):
-                    tensor_id = id(arg)
-
-                    if tensor_id not in param_id_to_name:
-                        continue
-                    else:
-                        arg = _replace_meta_param_with_real_param(arg)
-                        args[idx] = arg
-
-            # check kwargs for parameter replacement
-            for arg_name, arg in enumerate(kwargs):
-                if torch.is_tensor(arg):
-                    tensor_id = id(arg)
-
-                    if tensor_id not in param_id_to_name:
-                        continue
-                    else:
-                        arg = _replace_meta_param_with_real_param(arg)
-                        kwargs[arg_name] = arg
-
-            with torch.no_grad():
-                func(module, *args, **kwargs)
+        func, args, kwargs = self.module_rebuild_dict[model]
+        args = list(args)
+
+        # check args for parameter replacement
+        for idx, arg in enumerate(args):
+            arg = _process_arg(arg)
+            args[idx] = arg
+
+        # check kwargs for parameter replacement
+        for arg_name, arg in kwargs.items():
+            if arg_name == 'device':
+                arg = device
+            else:
+                arg = _process_arg(arg)
+            kwargs[arg_name] = arg
+
+        # build the user-specified model
+        with torch.no_grad():
+            func(model, *args, **kwargs)
+
+        return model
diff --git a/tests/test_utils/test_lazy_init_ctx.py b/tests/test_utils/test_lazy_init_ctx.py
index 4d4c0598c..fccf1588b 100644
--- a/tests/test_utils/test_lazy_init_ctx.py
+++ b/tests/test_utils/test_lazy_init_ctx.py
@@ -1,9 +1,20 @@
 import torch
 from colossalai.utils.model.lazy_init_context import LazyInitContext
 from torchvision.models import resnet34
+import random
+import numpy as np
+
+MANUAL_SEED = 0
+random.seed(MANUAL_SEED)
+np.random.seed(MANUAL_SEED)
+torch.manual_seed(MANUAL_SEED)
 
 
 def test_lazy_init():
+    cpu_rng_state = torch.get_rng_state()
+    origin_model = resnet34(num_classes=10)
+    origin_param_dict = dict(origin_model.named_parameters())
+    torch.set_rng_state(cpu_rng_state)
     ctx = LazyInitContext()
     with ctx:
         model = resnet34(num_classes=10)
@@ -16,6 +27,9 @@
         assert not param.is_meta
     for buffer in model.buffers():
         assert not buffer.is_meta
+    param_dict = dict(model.named_parameters())
+    for key in origin_param_dict.keys():
+        assert origin_param_dict[key].data.equal(param_dict[key].data)
 
 
 if __name__ == '__main__':
diff --git a/tests/test_utils/test_materialize_arbitary_lazy_module.py b/tests/test_utils/test_materialize_arbitary_lazy_module.py
new file mode 100644
index 000000000..b84293490
--- /dev/null
+++ b/tests/test_utils/test_materialize_arbitary_lazy_module.py
@@ -0,0 +1,55 @@
+import torch
+from colossalai.utils.model.lazy_init_context import LazyInitContext
+from torchvision.models import resnet34
+import random
+import numpy as np
+
+MANUAL_SEED = 0
+random.seed(MANUAL_SEED)
+np.random.seed(MANUAL_SEED)
+torch.manual_seed(MANUAL_SEED)
+
+
+class MLP(torch.nn.Module):
+
+    def __init__(self, dim: int = 4):
+        super().__init__()
+        intermediate_dim = dim * 4
+        self.dense_1 = torch.nn.Linear(dim, intermediate_dim)
+        self.activation = torch.nn.GELU()
+        self.dense_2 = torch.nn.Linear(intermediate_dim, dim)
+        self.dropout = torch.nn.Dropout(0.1)
+
+    def forward(self, x):
+        x = self.dense_1(x)
+        x = self.activation(x)
+        x = self.dense_2(x)
+        x = self.dropout(x)
+        return x
+
+
+def test_lazy_init():
+    cpu_rng_state = torch.get_rng_state()
+    origin_model = MLP()
+    origin_param_dict = dict(origin_model.named_parameters())
+    torch.set_rng_state(cpu_rng_state)
+    ctx = LazyInitContext()
+    with ctx:
+        model = MLP()
+    for param in model.parameters():
+        assert param.is_meta
+    for buffer in model.buffers():
+        assert buffer.is_meta
+    for module in model.children():
+        ctx.lazy_init_parameters(module)
+        for param in module.parameters():
+            assert not param.is_meta
+        for buffer in module.buffers():
+            assert not buffer.is_meta
+    param_dict = dict(model.named_parameters())
+    for key in origin_param_dict.keys():
+        assert origin_param_dict[key].data.equal(param_dict[key].data)
+
+
+if __name__ == '__main__':
+    test_lazy_init()
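For context, a minimal usage sketch of the behavior this change enables, mirroring the new test_materialize_arbitary_lazy_module.py test, which walks model.children() and rebuilds each child separately. It is an illustration only: MyNet is a hypothetical module defined just for this example and CPU placement is assumed; the actual tests above use resnet34 and MLP. Modules constructed under LazyInitContext keep their parameters on the meta device, and after this change any individual intercepted submodule can be materialized on its own via ctx.lazy_init_parameters.

import torch
from colossalai.utils.model.lazy_init_context import LazyInitContext


class MyNet(torch.nn.Module):
    # hypothetical module, defined only for this illustration

    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)
        self.head = torch.nn.Linear(8, 2)

    def forward(self, x):
        return self.head(self.proj(x))


ctx = LazyInitContext()
with ctx:
    # parameters are created on the meta device, so no real memory is allocated yet
    net = MyNet()

assert all(p.is_meta for p in net.parameters())

# materialize one intercepted submodule on its own (the case this change adds) ...
ctx.lazy_init_parameters(net.proj, device='cpu')
assert not any(p.is_meta for p in net.proj.parameters())

# ... and the remaining submodule later, independently
ctx.lazy_init_parameters(net.head, device='cpu')
assert not any(p.is_meta for p in net.parameters())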