From 250be4d31e7d1eb3b26107f2061c78d9ec673d2c Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Fri, 15 Jul 2022 17:47:12 +0800 Subject: [PATCH] [utils] integrated colotensor with lazy init context (#1324) * [utils] integrated colotensor with lazy init context * polish code * polish code * polish code --- colossalai/utils/model/lazy_init_context.py | 176 +++++++++----------- tests/test_utils/test_lazy_init_ctx.py | 35 ++-- 2 files changed, 103 insertions(+), 108 deletions(-) diff --git a/colossalai/utils/model/lazy_init_context.py b/colossalai/utils/model/lazy_init_context.py index a72c59fee..142ced630 100644 --- a/colossalai/utils/model/lazy_init_context.py +++ b/colossalai/utils/model/lazy_init_context.py @@ -2,13 +2,13 @@ # coding: utf-8 import torch -from colossalai.tensor import ColoParameter +import torch.nn as nn +from colossalai.tensor import ColoParameter, ColoTensor + import types import inspect -import typing from typing import List, Callable from colossalai.utils.model.utils import substitute_init_recursively -import copy class LazyInitContext(): @@ -18,8 +18,7 @@ class LazyInitContext(): Note: This API is only experimental and subject to future changes. - It should be integrated with meta tensor initialization in the future. - + Usage: with LazyInitContext() as ctx: model = nn.Linear(10, 10) @@ -36,14 +35,17 @@ class LazyInitContext(): assert not model.weight.is_meta and torch.all(model.weight == 0) Args: + to_meta (bool): optional, whether to initialize the model with meta tensors, default is False. extra_torch_tensor_func (List[str]): extra torch tensor functions related to value setting, such as `zero_` and `triu_`. `zero_` is pre-added by default. """ - tensor_set_value_func = ['zero_'] + tensor_set_value_func = ['zero_', 'fill_'] - def __init__(self, extra_torch_tensor_func: List[str] = None): - self._intercepted_init_func_cache = [] + def __init__(self, to_meta: bool = False, extra_torch_tensor_func: List[str] = None): + # TODO: hijack the torch constructor functions as well + self._to_meta = to_meta + self._intercepted_nn_init_func_cache = {} self._nn_init_methods = self._get_nn_init_methods() self._torch_mod_cls = torch.nn.modules.module.Module @@ -53,14 +55,20 @@ class LazyInitContext(): else: self._torch_tensor_funcs = self.tensor_set_value_func - def _cache_func(self, func): + @property + def to_meta(self): + return self._to_meta + + def _cache_init_func(self, func): """ - This method wraps the ``torch.nn.init`` method so that the function call - is cached instead of being executed. + This method wraps the ``torch.nn.init`` method and torch tensor value-setting functions + so that the function call is cached instead of being executed. 
""" - def wrapped_init_func(*args, **kwargs): - self._intercepted_init_func_cache.append(dict(func=func, args=args, kwargs=kwargs)) + def wrapped_init_func(tensor, *args, **kwargs): + if tensor not in self._intercepted_nn_init_func_cache: + self._intercepted_nn_init_func_cache[tensor] = [] + self._intercepted_nn_init_func_cache[tensor].append((func, args, kwargs)) return wrapped_init_func @@ -76,17 +84,10 @@ class LazyInitContext(): for name in nn_init_method_names: nn_init_methods.append((name, getattr(torch.nn.init, name))) - def _has_tensor_in_arg(func): - hints = typing.get_type_hints(func) - for k, v in hints.items(): - if v is torch.Tensor: - return True - return False - def _is_init_method(item): name, func = item - if (not isinstance(func, types.FunctionType) or name.startswith('_') or not name.endswith('_') - or not _has_tensor_in_arg(func)): + + if (not isinstance(func, types.FunctionType) or name.startswith('_') or not name.endswith('_')): return False else: return True @@ -103,11 +104,13 @@ class LazyInitContext(): has_device = 'device' in inspect.signature(func).parameters def layer_lazy_init(module, *args, **kwargs): - self._intercepted_init_func_cache.append( - dict(func=func, module=module, args=args, kwargs=copy.deepcopy(kwargs))) + # if this module contains device argument + # we set it to meta to initialize as meta backend if has_device: kwargs['device'] = 'meta' func(module, *args, **kwargs) + + # if device is not found, we intialize it and convert to meta if not has_device: module.to('meta') @@ -122,7 +125,7 @@ class LazyInitContext(): def _patch_nn_init_funcs(self): # patch nn.init functions for name, func in self._nn_init_methods: - setattr(torch.nn.init, name, self._cache_func(func)) + setattr(torch.nn.init, name, self._cache_init_func(func)) def _unpatch_nn_init_funcs(self): # unpatch nn.init functions @@ -150,7 +153,7 @@ class LazyInitContext(): origin_func_name = self._get_tmp_origin_func_ref(func_name) origin_func = getattr(torch.Tensor, func_name) setattr(torch.Tensor, origin_func_name, origin_func) - setattr(torch.Tensor, func_name, self._cache_func(origin_func)) + setattr(torch.Tensor, func_name, self._cache_init_func(origin_func)) def _unpatch_torch_tensor_funcs(self): for func_name in self._torch_tensor_funcs: @@ -159,17 +162,18 @@ class LazyInitContext(): setattr(torch.Tensor, func_name, origin_func) def __enter__(self): - self._patch_submodule_init() + self._patch_torch_tensor_funcs() + self._patch_nn_init_funcs() + + if self._to_meta: + self._patch_submodule_init() return self def __exit__(self, *args, **kwargs): - self._unpatch_submodule_init() - # build model_rebuild_dict in reverse order to make sure get correct init func for inherited class. - self.module_rebuild_dict = {} - self._intercepted_init_func_cache.reverse() - for cache in self._intercepted_init_func_cache: - self.module_rebuild_dict[cache['module']] = (cache['func'], cache['args'], cache['kwargs']) - self._intercepted_init_func_cache.reverse() + if self._to_meta: + self._unpatch_submodule_init() + self._unpatch_nn_init_funcs() + self._unpatch_torch_tensor_funcs() def lazy_init_parameters(self, model: torch.nn.Module, device='cpu', call_back: Callable = None): """ @@ -178,80 +182,56 @@ class LazyInitContext(): Args: model (`torch.nn.Module`): the model instantiated under the context. 
device (str): the device on which weights are initialized - + """ - # build param mapping - param_id_to_name = dict() - for name, param in model.named_parameters(): - param_id_to_name[id(param)] = name - for name, buffer in model.named_buffers(): - param_id_to_name[id(buffer)] = name - assert model in self.module_rebuild_dict, 'We only support rebuild modules which intercepted during initializing by us.' + def _init_recursively(module: nn.Module): + # recursively initialize the module + for mod in module.children(): + _init_recursively(mod) - def _process_arg(arg): - """ - Process args recursively. If arg is a torch.nn.Module instance in module_rebuild_dict, - we need to rebuild it with real parameters. If arg is a tuple or list, we will process - the element of arg with this function again. - """ - if torch.is_tensor(arg): - tensor_id = id(arg) - if tensor_id in param_id_to_name: - arg = _replace_meta_param_with_real_param(arg) + # initialize and shard tensors directly attached to the current module + for name, param in module.named_parameters(recurse=False): + _init_and_shard(module, name, param) - elif isinstance(arg, torch.nn.Module): - if arg in self.module_rebuild_dict: - arg = self.lazy_init_parameters(model=arg, device=device, call_back=call_back) + for name, buf in module.named_buffers(recurse=False): + _init_and_shard(module, name, buf) - elif isinstance(arg, (tuple, list)): - rst_list = [] - for element in arg: - processed_element = _process_arg(element) - rst_list.append(processed_element) - arg = rst_list - return arg + @torch.no_grad() + def _init_and_shard(module, name, tensor): + # check whether the tensor is a buffer or parameter + is_param = isinstance(tensor, nn.parameter.Parameter) - def _replace_meta_param_with_real_param(meta_param): - if meta_param.device != 'meta': - return meta_param - tensor_id = id(meta_param) - param_full_name = param_id_to_name[tensor_id] - real_param = torch.empty_like(meta_param, dtype=meta_param.dtype, device=device) - real_param = ColoParameter(real_param, requires_grad=meta_param.requires_grad) + # get sharding spec + dist_spec = getattr(tensor, 'dist_spec', None) + pg = getattr(tensor, 'pg', None) - if '.' 
in param_full_name: - submodule_name, param_name = param_full_name.rsplit('.', 1) - submodule = model.get_submodule(submodule_name) + # convert the tensor from meta to materialized one + if tensor.is_meta: + materialized_tensor = torch.empty_like(tensor, device=device) + # if this tensor is a meta tensor, it must have an init function + assert tensor in self._intercepted_nn_init_func_cache + tensor = materialized_tensor + + # apply init function + if tensor in self._intercepted_nn_init_func_cache: + init_func, args, kwargs = self._intercepted_nn_init_func_cache[tensor][-1] + init_func(tensor, *args, **kwargs) + + # convert it to ColoTensor or ColoParameter + if is_param: + tensor = ColoParameter.from_torch_tensor(tensor, requires_grad=tensor.requires_grad) else: - submodule = model - param_name = param_full_name - setattr(submodule, param_name, real_param) + tensor = ColoTensor.from_torch_tensor(tensor) - # execute call_back function on the materailized tensor - # this can where sharding comes in - if call_back: - call_back(real_param) - return real_param + # apply sharding + if dist_spec: + tensor = tensor.redistribute(dist_spec=dist_spec, pg=pg) - func, args, kwargs = self.module_rebuild_dict[model] - args = list(args) + # override the original tensor + with torch.no_grad(): + setattr(module, name, tensor) - # check args for parameter replacement - for idx, arg in enumerate(args): - arg = _process_arg(arg) - args[idx] = arg - - # check kwargs for parameter replacement - for arg_name, arg in kwargs.items(): - if arg_name == 'device': - arg = device - else: - arg = _process_arg(arg) - kwargs[arg_name] = arg - - # build user specified model - with torch.no_grad(): - func(model, *args, **kwargs) + _init_recursively(model) return model diff --git a/tests/test_utils/test_lazy_init_ctx.py b/tests/test_utils/test_lazy_init_ctx.py index fccf1588b..97efb3367 100644 --- a/tests/test_utils/test_lazy_init_ctx.py +++ b/tests/test_utils/test_lazy_init_ctx.py @@ -10,27 +10,42 @@ np.random.seed(MANUAL_SEED) torch.manual_seed(MANUAL_SEED) -def test_lazy_init(): - cpu_rng_state = torch.get_rng_state() - origin_model = resnet34(num_classes=10) - origin_param_dict = dict(origin_model.named_parameters()) - torch.set_rng_state(cpu_rng_state) - ctx = LazyInitContext() +def test_lazy_init_with_meta(): + ctx = LazyInitContext(to_meta=True) with ctx: model = resnet34(num_classes=10) + for param in model.parameters(): assert param.is_meta for buffer in model.buffers(): assert buffer.is_meta + ctx.lazy_init_parameters(model) + + for name, param in model.named_parameters(): + assert not param.is_meta, name + + for buffer in model.buffers(): + assert not buffer.is_meta + + +def test_lazy_init_without_meta(): + ctx = LazyInitContext(to_meta=False) + with ctx: + model = resnet34(num_classes=10) + for param in model.parameters(): assert not param.is_meta for buffer in model.buffers(): assert not buffer.is_meta - param_dict = dict(model.named_parameters()) - for key in origin_param_dict.keys(): - assert origin_param_dict[key].data.equal(param_dict[key].data) + + conv1_weight_before_init = model.conv1.weight.clone() + ctx.lazy_init_parameters(model) + conv1_weight_after_init = model.conv1.weight.clone() + + assert not torch.allclose(conv1_weight_after_init, conv1_weight_before_init) if __name__ == '__main__': - test_lazy_init() + test_lazy_init_with_meta() + test_lazy_init_without_meta()
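
For reviewers, the snippet below sketches how the patched LazyInitContext is meant to be used end to end. It is an illustrative sketch only and is not shipped with this patch: the import path simply mirrors the file modified above (the class may also be re-exported elsewhere in the package), and a small nn.Linear stack stands in for the resnet34 model used in the updated test.

# Usage sketch (not part of this patch) for the patched LazyInitContext.
import torch
import torch.nn as nn

from colossalai.utils.model.lazy_init_context import LazyInitContext

# 1) to_meta=True: submodule __init__ is hijacked, so parameters are created
#    on the 'meta' device and no real memory is allocated during construction.
ctx = LazyInitContext(to_meta=True)
with ctx:
    model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 2))
assert all(p.is_meta for p in model.parameters())

# Materialize the weights on the requested device; parameters come back as
# ColoParameter and buffers as ColoTensor instances.
ctx.lazy_init_parameters(model, device='cpu')
assert not any(p.is_meta for p in model.parameters())

# 2) to_meta=False: tensors are allocated immediately, but the intercepted
#    nn.init / value-setting calls are only replayed by lazy_init_parameters.
ctx = LazyInitContext(to_meta=False)
with ctx:
    model = nn.Linear(10, 10)
weight_before = model.weight.clone()
ctx.lazy_init_parameters(model, device='cpu')
# the cached initialization has now been applied, so the weight has changed
assert not torch.allclose(model.weight, weight_before)

The design choice visible in __enter__/__exit__ is that the nn.init and tensor value-setting functions are always patched inside the context, while Module.__init__ is only hijacked when to_meta=True; the two new tests exercise exactly these two paths.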