ColossalAI/colossalai/utils/model/lazy_init_context.py

224 lines
8.3 KiB
Python

#!/usr/bin/env python
# coding: utf-8
import functools
import inspect
import types
import typing
from typing import Callable, List, Optional

import torch

from colossalai.tensor import ColoParameter
class LazyInitContext:
    """
    A context to allow for lazy weight initialization of PyTorch modules. It intercepts the tensor
    initialization functions so that their calls are recorded instead of executed, and replays the
    recorded calls on real tensors in :meth:`lazy_init_parameters`.

    Note:
        This API is only experimental and subject to future changes.
        It should be integrated with meta tensor initialization in the future.

    Usage:
        with LazyInitContext() as ctx:
            model = nn.Linear(10, 10)
            model.weight.zero_()

        # make sure the weight is a meta tensor
        assert model.weight.is_meta

        # initialize weights
        ctx.lazy_init_parameters(model)

        # make sure the weight is not a meta tensor
        # and initialized correctly
        assert not model.weight.is_meta and torch.all(model.weight == 0)

    Args:
        extra_torch_tensor_func (List[str]): extra torch tensor functions related
            to value setting, such as `zero_` and `triu_`. `zero_` is pre-added by default.
            Duplicate names are ignored.
    """

    tensor_set_value_func = ['zero_']

    def __init__(self, extra_torch_tensor_func: Optional[List[str]] = None):
        # Intercepted calls in call order, each stored as dict(func=..., args=..., kwargs=...).
        self._intercepted_init_func_cache = []
        self._nn_init_methods = self._get_nn_init_methods()
        self._torch_mod_cls = torch.nn.modules.module.Module
        # (class, had_own_init, original __init__) for every class patched on entry.
        self._patched_mod_cls = []

        tensor_funcs = list(self.tensor_set_value_func)
        if extra_torch_tensor_func:
            tensor_funcs.extend(extra_torch_tensor_func)
        # De-duplicate while keeping order: patching the same name twice would store the first
        # wrapper as the "original", leaving torch.Tensor patched after __exit__.
        self._torch_tensor_funcs = tuple(dict.fromkeys(tensor_funcs))

    def _cache_func(self, func):
        """
        Wrap ``func`` (a ``torch.nn.init`` function or a ``torch.Tensor`` value-setting method)
        so that the call is cached instead of being executed.

        The wrapper returns the tensor it was called on (the first positional argument, or the
        ``tensor`` keyword argument), like the in-place functions it replaces, so chained or
        assigned calls keep working inside the context. No values are written to that tensor
        until the cached call is replayed.
        """

        @functools.wraps(func)
        def wrapped_init_func(*args, **kwargs):
            self._intercepted_init_func_cache.append(dict(func=func, args=args, kwargs=kwargs))
            if args:
                return args[0]
            return kwargs.get('tensor')

        return wrapped_init_func

    def _get_nn_init_methods(self):
        """
        Collect the public initializer functions of ``torch.nn.init``.

        Returns:
            List of ``(name, function)`` pairs for public functions whose name ends with ``_``
            and which take a tensor argument.
        """

        def _has_tensor_in_arg(func):
            # Prefer explicit ``torch.Tensor`` parameter annotations of ``func`` itself.
            try:
                hints = typing.get_type_hints(func)
            except Exception:
                # annotations that cannot be evaluated are treated as absent
                hints = {}
            if any(hint is torch.Tensor for arg_name, hint in hints.items() if arg_name != 'return'):
                return True
            # Some torch releases leave initializers (e.g. ``kaiming_uniform_``) unannotated;
            # fall back to the conventional ``tensor`` parameter name.
            try:
                return 'tensor' in inspect.signature(func).parameters
            except (TypeError, ValueError):
                return False

        def _is_init_method(name, func):
            return (isinstance(func, types.FunctionType)
                    and not name.startswith('_')
                    and name.endswith('_')
                    and _has_tensor_in_arg(func))

        nn_init_methods = []
        # look for all methods in ``torch.nn.init`` module
        for name in dir(torch.nn.init):
            func = getattr(torch.nn.init, name)
            if _is_init_method(name, func):
                nn_init_methods.append((name, func))
        return nn_init_methods

    def _wrap_module_init(self, func):
        """
        Wrap a module ``__init__`` so that, when it accepts a ``device`` parameter, it is always
        called with ``device='meta'`` (overriding any caller value) and parameters are created
        as meta tensors.
        """
        try:
            has_device = 'device' in inspect.signature(func).parameters
        except (TypeError, ValueError):
            # no retrievable signature: leave the call untouched
            has_device = False

        @functools.wraps(func)
        def layer_lazy_init(*args, **kwargs):
            if has_device:
                kwargs['device'] = 'meta'
            func(*args, **kwargs)

        return layer_lazy_init

    def _get_tmp_origin_func_ref(self, name):
        """
        Generate a function name for consistency during caching and retrieving.
        """
        return f'_orig_{name}'

    def _patch_nn_init_funcs(self):
        # patch nn.init functions
        for name, func in self._nn_init_methods:
            setattr(torch.nn.init, name, self._cache_func(func))

    def _unpatch_nn_init_funcs(self):
        # unpatch nn.init functions
        for name, func in self._nn_init_methods:
            setattr(torch.nn.init, name, func)

    def _patch_submodule_init(self):
        # Patch the __init__ of every direct subclass of nn.Module, recording whether the class
        # defined __init__ itself so that its exact original state can be restored on exit.
        self._patched_mod_cls = []
        for sub_cls in self._torch_mod_cls.__subclasses__():
            had_own_init = '__init__' in sub_cls.__dict__
            orig_init = sub_cls.__init__
            sub_cls.__init__ = self._wrap_module_init(orig_init)
            self._patched_mod_cls.append((sub_cls, had_own_init, orig_init))

    def _unpatch_submodule_init(self):
        # Use the recorded list rather than __subclasses__(): classes defined inside the
        # context were never patched and must not be touched.
        for sub_cls, had_own_init, orig_init in self._patched_mod_cls:
            if had_own_init:
                sub_cls.__init__ = orig_init
            else:
                # remove the wrapper so the inherited __init__ is visible again
                del sub_cls.__init__
        self._patched_mod_cls = []

    def _patch_torch_tensor_funcs(self):
        # patch tensor value-setting functions, stashing the originals on torch.Tensor
        for func_name in self._torch_tensor_funcs:
            origin_func_name = self._get_tmp_origin_func_ref(func_name)
            origin_func = getattr(torch.Tensor, func_name)
            setattr(torch.Tensor, origin_func_name, origin_func)
            setattr(torch.Tensor, func_name, self._cache_func(origin_func))

    def _unpatch_torch_tensor_funcs(self):
        for func_name in self._torch_tensor_funcs:
            origin_func_name = self._get_tmp_origin_func_ref(func_name)
            origin_func = getattr(torch.Tensor, origin_func_name)
            setattr(torch.Tensor, func_name, origin_func)

    def __enter__(self):
        self._patch_nn_init_funcs()
        self._patch_torch_tensor_funcs()
        self._patch_submodule_init()
        return self

    def __exit__(self, *args, **kwargs):
        self._unpatch_submodule_init()
        self._unpatch_torch_tensor_funcs()
        self._unpatch_nn_init_funcs()

    def lazy_init_parameters(self, model: torch.nn.Module, device='cpu', call_back: Optional[Callable] = None):
        """
        Materialize the meta parameters of ``model`` and replay the intercepted initialization
        calls on them.

        Each parameter referenced by a cached call is replaced (in every module that holds it)
        by a ``ColoParameter`` allocated on ``device``. This happens once, the first time the
        parameter is referenced, so later cached calls build on earlier ones. Cached calls then
        run under ``torch.no_grad()`` in their original order.

        NOTE: parameters that no intercepted call references are not materialized here and
        stay on the meta device.

        Args:
            model (`torch.nn.Module`): the model instantiated under the context.
            device (str): the device on which weights are initialized.
            call_back (Callable, optional): called once with each newly created parameter,
                after it is attached to the model and before cached calls run on it
                (e.g. where sharding can be applied).
        """
        # id(meta param) -> every (module, attribute name) slot holding it. Walking each
        # module's own parameters keeps all owners of tied/shared parameters, which
        # ``model.named_parameters()`` would de-duplicate away.
        param_slots = {}
        for submodule in model.modules():
            for param_name, param in submodule._parameters.items():
                if param is not None:
                    param_slots.setdefault(id(param), []).append((submodule, param_name))

        # id(meta param) -> its materialized replacement
        materialized = {}

        def _replace_meta_param_with_real_param(meta_param):
            tensor_id = id(meta_param)
            if tensor_id in materialized:
                return materialized[tensor_id]
            real_param = torch.empty_like(meta_param, dtype=meta_param.dtype, device=device)
            real_param = ColoParameter(real_param, requires_grad=meta_param.requires_grad)
            for submodule, param_name in param_slots[tensor_id]:
                setattr(submodule, param_name, real_param)
            materialized[tensor_id] = real_param
            # execute call_back function on the materialized tensor;
            # this is where sharding can come in
            if call_back:
                call_back(real_param)
            return real_param

        def _maybe_materialize(value):
            # only tensors that are parameters of ``model`` get replaced
            if torch.is_tensor(value) and id(value) in param_slots:
                return _replace_meta_param_with_real_param(value)
            return value

        # replay the cached calls; args/kwargs are rebuilt so the cache itself is not mutated
        for cache in self._intercepted_init_func_cache:
            func = cache['func']
            args = [_maybe_materialize(arg) for arg in cache['args']]
            kwargs = {arg_name: _maybe_materialize(arg) for arg_name, arg in cache['kwargs'].items()}
            with torch.no_grad():
                func(*args, **kwargs)