[utils] integrated colotensor with lazy init context (#1324)

* [utils] integrated colotensor with lazy init context

* polish code

* polish code

* polish code
Frank Lee 2022-07-15 17:47:12 +08:00 committed by GitHub
parent 659a740738
commit 250be4d31e
2 changed files with 103 additions and 108 deletions

View File

@@ -2,13 +2,13 @@
 # coding: utf-8
 import torch
-from colossalai.tensor import ColoParameter
+import torch.nn as nn
+from colossalai.tensor import ColoParameter, ColoTensor
 import types
 import inspect
-import typing
 from typing import List, Callable
 from colossalai.utils.model.utils import substitute_init_recursively
-import copy
 
 
 class LazyInitContext():
@@ -18,7 +18,6 @@ class LazyInitContext():
 
     Note:
         This API is only experimental and subject to future changes.
-        It should be integrated with meta tensor initialization in the future.
 
     Usage:
         with LazyInitContext() as ctx:
@@ -36,14 +35,17 @@
         assert not model.weight.is_meta and torch.all(model.weight == 0)
 
     Args:
+        to_meta (bool): optional, whether to initialize the model with meta tensors, default is False.
         extra_torch_tensor_func (List[str]): extra torch tensor functions related
             to value setting, such as `zero_` and `triu_`. `zero_` is pre-added by default.
     """
 
-    tensor_set_value_func = ['zero_']
+    tensor_set_value_func = ['zero_', 'fill_']
 
-    def __init__(self, extra_torch_tensor_func: List[str] = None):
-        self._intercepted_init_func_cache = []
+    def __init__(self, to_meta: bool = False, extra_torch_tensor_func: List[str] = None):
+        # TODO: hijack the torch constructor functions as well
+        self._to_meta = to_meta
+        self._intercepted_nn_init_func_cache = {}
         self._nn_init_methods = self._get_nn_init_methods()
         self._torch_mod_cls = torch.nn.modules.module.Module
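
Illustration (not part of the diff): based on the constructor signature and docstring in this hunk, extra in-place value-setting tensor methods can be intercepted by name in addition to the pre-added 'zero_' and 'fill_'. The module path in the import is an assumption inferred from the colossalai.utils.model import in the first hunk; 'triu_' is just an example name.

    # hedged usage sketch of the documented constructor arguments
    from colossalai.utils.model.lazy_init_context import LazyInitContext   # assumed module path

    ctx = LazyInitContext(to_meta=True, extra_torch_tensor_func=['triu_'])
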
@@ -53,14 +55,20 @@
         else:
             self._torch_tensor_funcs = self.tensor_set_value_func
 
-    def _cache_func(self, func):
+    @property
+    def to_meta(self):
+        return self._to_meta
+
+    def _cache_init_func(self, func):
         """
-        This method wraps the ``torch.nn.init`` method so that the function call
-        is cached instead of being executed.
+        This method wraps the ``torch.nn.init`` method and torch tensor value-setting functions
+        so that the function call is cached instead of being executed.
         """
 
-        def wrapped_init_func(*args, **kwargs):
-            self._intercepted_init_func_cache.append(dict(func=func, args=args, kwargs=kwargs))
+        def wrapped_init_func(tensor, *args, **kwargs):
+            if tensor not in self._intercepted_nn_init_func_cache:
+                self._intercepted_nn_init_func_cache[tensor] = []
+            self._intercepted_nn_init_func_cache[tensor].append((func, args, kwargs))
 
         return wrapped_init_func
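
A minimal, standalone sketch (names are illustrative, not from the repository) of the wrap-and-cache pattern implemented by _cache_init_func above: the patched torch.nn.init function records the call keyed by the target tensor instead of executing it, so it can be replayed later on a materialized tensor.

    import torch

    _cache = {}    # maps a tensor to the init calls recorded for it

    def cache_instead_of_execute(func):
        def wrapped(tensor, *args, **kwargs):
            _cache.setdefault(tensor, []).append((func, args, kwargs))
        return wrapped

    # patch one init function, record a call, unpatch, then replay on a real tensor
    origin_uniform_ = torch.nn.init.uniform_
    torch.nn.init.uniform_ = cache_instead_of_execute(origin_uniform_)

    t = torch.empty(2, 2)
    torch.nn.init.uniform_(t, a=-1.0, b=1.0)    # recorded, not executed
    torch.nn.init.uniform_ = origin_uniform_    # restore the original function

    func, args, kwargs = _cache[t][-1]
    func(t, *args, **kwargs)                    # replay the cached initialization
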
@@ -76,17 +84,10 @@
         for name in nn_init_method_names:
             nn_init_methods.append((name, getattr(torch.nn.init, name)))
 
-        def _has_tensor_in_arg(func):
-            hints = typing.get_type_hints(func)
-            for k, v in hints.items():
-                if v is torch.Tensor:
-                    return True
-            return False
-
         def _is_init_method(item):
             name, func = item
-            if (not isinstance(func, types.FunctionType) or name.startswith('_') or not name.endswith('_')
-                    or not _has_tensor_in_arg(func)):
+
+            if (not isinstance(func, types.FunctionType) or name.startswith('_') or not name.endswith('_')):
                 return False
             else:
                 return True
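
Illustration (not from the diff): with the simplified predicate above, nn.init functions are selected purely by name and type, so in-place initializers such as uniform_, normal_ and kaiming_uniform_ are picked up, while helpers like calculate_gain (no trailing underscore) and private _no_grad_* functions are skipped. A quick standalone check:

    import types
    import torch

    def _is_init_method(item):
        # same predicate as in the hunk above
        name, func = item
        return isinstance(func, types.FunctionType) and not name.startswith('_') and name.endswith('_')

    selected = sorted(name for name, func in vars(torch.nn.init).items() if _is_init_method((name, func)))
    print(selected)    # expected to include 'kaiming_uniform_', 'normal_', 'uniform_', 'zeros_', ...
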
@@ -103,11 +104,13 @@
             has_device = 'device' in inspect.signature(func).parameters
 
             def layer_lazy_init(module, *args, **kwargs):
-                self._intercepted_init_func_cache.append(
-                    dict(func=func, module=module, args=args, kwargs=copy.deepcopy(kwargs)))
+                # if this module's constructor has a device argument,
+                # set it to meta so the module is initialized on the meta backend
                 if has_device:
                     kwargs['device'] = 'meta'
                 func(module, *args, **kwargs)
+
+                # if no device argument is found, initialize normally and then convert to meta
                 if not has_device:
                     module.to('meta')
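
Background for the branch above (illustrative, not from the diff): module constructors that accept a device argument can build their parameters directly on the meta device, while modules without it are initialized normally and then moved with .to('meta').

    import torch.nn as nn

    # constructor accepts `device`: parameters are created as meta tensors directly
    linear = nn.Linear(4, 4, device='meta')
    print(linear.weight.is_meta)    # True

    # fallback path: build normally, then convert the whole module to meta
    conv = nn.Conv2d(3, 8, kernel_size=3)
    conv.to('meta')
    print(conv.weight.is_meta)      # True
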
@@ -122,7 +125,7 @@
     def _patch_nn_init_funcs(self):
         # patch nn.init functions
         for name, func in self._nn_init_methods:
-            setattr(torch.nn.init, name, self._cache_func(func))
+            setattr(torch.nn.init, name, self._cache_init_func(func))
 
     def _unpatch_nn_init_funcs(self):
         # unpatch nn.init functions
@@ -150,7 +153,7 @@
             origin_func_name = self._get_tmp_origin_func_ref(func_name)
             origin_func = getattr(torch.Tensor, func_name)
             setattr(torch.Tensor, origin_func_name, origin_func)
-            setattr(torch.Tensor, func_name, self._cache_func(origin_func))
+            setattr(torch.Tensor, func_name, self._cache_init_func(origin_func))
 
     def _unpatch_torch_tensor_funcs(self):
         for func_name in self._torch_tensor_funcs:
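
A standalone sketch of the stash-and-patch pattern used for torch.Tensor value-setting methods in this hunk: the original method is kept under a temporary attribute so it can be restored on unpatch. The helper name below is hypothetical, standing in for _get_tmp_origin_func_ref; the caching wrapper is replaced by a no-op for brevity.

    import torch

    def _tmp_name(func_name):
        # hypothetical stand-in for _get_tmp_origin_func_ref
        return '_origin_' + func_name

    # stash the original method on torch.Tensor, then replace it with a no-op stand-in
    origin = getattr(torch.Tensor, 'zero_')
    setattr(torch.Tensor, _tmp_name('zero_'), origin)
    setattr(torch.Tensor, 'zero_', lambda self, *args, **kwargs: None)

    t = torch.ones(2)
    t.zero_()          # intercepted: the tensor is left unchanged
    print(t)           # tensor([1., 1.])

    # unpatch by restoring the stashed original
    setattr(torch.Tensor, 'zero_', getattr(torch.Tensor, _tmp_name('zero_')))
    t.zero_()
    print(t)           # tensor([0., 0.])
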
@@ -159,17 +162,18 @@
             setattr(torch.Tensor, func_name, origin_func)
 
     def __enter__(self):
-        self._patch_submodule_init()
+        self._patch_torch_tensor_funcs()
+        self._patch_nn_init_funcs()
+
+        if self._to_meta:
+            self._patch_submodule_init()
         return self
 
     def __exit__(self, *args, **kwargs):
-        self._unpatch_submodule_init()
-        # build model_rebuild_dict in reverse order to make sure get correct init func for inherited class.
-        self.module_rebuild_dict = {}
-        self._intercepted_init_func_cache.reverse()
-        for cache in self._intercepted_init_func_cache:
-            self.module_rebuild_dict[cache['module']] = (cache['func'], cache['args'], cache['kwargs'])
-        self._intercepted_init_func_cache.reverse()
+        if self._to_meta:
+            self._unpatch_submodule_init()
+        self._unpatch_nn_init_funcs()
+        self._unpatch_torch_tensor_funcs()
 
     def lazy_init_parameters(self, model: torch.nn.Module, device='cpu', call_back: Callable = None):
         """
@@ -180,78 +184,54 @@
             device (str): the device on which weights are initialized
         """
 
-        # build param mapping
-        param_id_to_name = dict()
-        for name, param in model.named_parameters():
-            param_id_to_name[id(param)] = name
-        for name, buffer in model.named_buffers():
-            param_id_to_name[id(buffer)] = name
-
-        assert model in self.module_rebuild_dict, 'We only support rebuild modules which intercepted during initializing by us.'
-
-        def _process_arg(arg):
-            """
-            Process args recursively. If arg is a torch.nn.Module instance in module_rebuild_dict,
-            we need to rebuild it with real parameters. If arg is a tuple or list, we will process
-            the element of arg with this function again.
-            """
-            if torch.is_tensor(arg):
-                tensor_id = id(arg)
-                if tensor_id in param_id_to_name:
-                    arg = _replace_meta_param_with_real_param(arg)
-            elif isinstance(arg, torch.nn.Module):
-                if arg in self.module_rebuild_dict:
-                    arg = self.lazy_init_parameters(model=arg, device=device, call_back=call_back)
-            elif isinstance(arg, (tuple, list)):
-                rst_list = []
-                for element in arg:
-                    processed_element = _process_arg(element)
-                    rst_list.append(processed_element)
-                arg = rst_list
-            return arg
-
-        def _replace_meta_param_with_real_param(meta_param):
-            if meta_param.device != 'meta':
-                return meta_param
-            tensor_id = id(meta_param)
-            param_full_name = param_id_to_name[tensor_id]
-            real_param = torch.empty_like(meta_param, dtype=meta_param.dtype, device=device)
-            real_param = ColoParameter(real_param, requires_grad=meta_param.requires_grad)
-
-            if '.' in param_full_name:
-                submodule_name, param_name = param_full_name.rsplit('.', 1)
-                submodule = model.get_submodule(submodule_name)
-            else:
-                submodule = model
-                param_name = param_full_name
-            setattr(submodule, param_name, real_param)
-
-            # execute call_back function on the materailized tensor
-            # this can where sharding comes in
-            if call_back:
-                call_back(real_param)
-            return real_param
-
-        func, args, kwargs = self.module_rebuild_dict[model]
-        args = list(args)
-
-        # check args for parameter replacement
-        for idx, arg in enumerate(args):
-            arg = _process_arg(arg)
-            args[idx] = arg
-
-        # check kwargs for parameter replacement
-        for arg_name, arg in kwargs.items():
-            if arg_name == 'device':
-                arg = device
-            else:
-                arg = _process_arg(arg)
-            kwargs[arg_name] = arg
-
-        # build user specified model
-        with torch.no_grad():
-            func(model, *args, **kwargs)
+        def _init_recursively(module: nn.Module):
+            # recursively initialize the module
+            for mod in module.children():
+                _init_recursively(mod)
+
+            # initialize and shard tensors directly attached to the current module
+            for name, param in module.named_parameters(recurse=False):
+                _init_and_shard(module, name, param)
+
+            for name, buf in module.named_buffers(recurse=False):
+                _init_and_shard(module, name, buf)
+
+        @torch.no_grad()
+        def _init_and_shard(module, name, tensor):
+            # check whether the tensor is a buffer or parameter
+            is_param = isinstance(tensor, nn.parameter.Parameter)
+
+            # get sharding spec
+            dist_spec = getattr(tensor, 'dist_spec', None)
+            pg = getattr(tensor, 'pg', None)
+
+            # convert the tensor from meta to materialized one
+            if tensor.is_meta:
+                materialized_tensor = torch.empty_like(tensor, device=device)
+                # if this tensor is a meta tensor, it must have an init function
+                assert tensor in self._intercepted_nn_init_func_cache
+                tensor = materialized_tensor
+
+            # apply init function
+            if tensor in self._intercepted_nn_init_func_cache:
+                init_func, args, kwargs = self._intercepted_nn_init_func_cache[tensor][-1]
+                init_func(tensor, *args, **kwargs)
+
+            # convert it to ColoTensor or ColoParameter
+            if is_param:
+                tensor = ColoParameter.from_torch_tensor(tensor, requires_grad=tensor.requires_grad)
+            else:
+                tensor = ColoTensor.from_torch_tensor(tensor)
+
+            # apply sharding
+            if dist_spec:
+                tensor = tensor.redistribute(dist_spec=dist_spec, pg=pg)
+
+            # override the original tensor
+            with torch.no_grad():
+                setattr(module, name, tensor)
+
+        _init_recursively(model)
 
         return model
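
Condensed, illustrative sketch (not from the diff) of the per-tensor core of _init_and_shard above, with the cached-init replay and the sharding step omitted. It assumes, as the tests below do, that ColoTensor/ColoParameter can be constructed without a distributed setup; the function and variable names here are hypothetical.

    import torch
    import torch.nn as nn
    from colossalai.tensor import ColoParameter, ColoTensor    # same import as in the first hunk

    @torch.no_grad()
    def materialize_and_wrap(module: nn.Module, name: str, tensor: torch.Tensor, device='cpu'):
        is_param = isinstance(tensor, nn.parameter.Parameter)

        # materialize a meta tensor on the target device
        if tensor.is_meta:
            tensor = torch.empty_like(tensor, device=device)

        # (the real implementation replays the cached nn.init call here)

        # wrap into the Colossal-AI tensor types and re-attach to the module
        if is_param:
            tensor = ColoParameter.from_torch_tensor(tensor, requires_grad=tensor.requires_grad)
        else:
            tensor = ColoTensor.from_torch_tensor(tensor)
        setattr(module, name, tensor)

    # hypothetical usage on a single meta-initialized layer
    layer = nn.Linear(4, 4, device='meta')
    materialize_and_wrap(layer, 'weight', layer.weight)
    print(type(layer.weight), layer.weight.is_meta)    # ColoParameter, False
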

View File

@@ -10,27 +10,42 @@ np.random.seed(MANUAL_SEED)
 torch.manual_seed(MANUAL_SEED)
 
 
-def test_lazy_init():
-    cpu_rng_state = torch.get_rng_state()
-    origin_model = resnet34(num_classes=10)
-    origin_param_dict = dict(origin_model.named_parameters())
-    torch.set_rng_state(cpu_rng_state)
-
-    ctx = LazyInitContext()
+def test_lazy_init_with_meta():
+    ctx = LazyInitContext(to_meta=True)
     with ctx:
         model = resnet34(num_classes=10)
 
     for param in model.parameters():
         assert param.is_meta
     for buffer in model.buffers():
         assert buffer.is_meta
 
     ctx.lazy_init_parameters(model)
 
+    for name, param in model.named_parameters():
+        assert not param.is_meta, name
+
+    for buffer in model.buffers():
+        assert not buffer.is_meta
+
+
+def test_lazy_init_without_meta():
+    ctx = LazyInitContext(to_meta=False)
+    with ctx:
+        model = resnet34(num_classes=10)
+
     for param in model.parameters():
         assert not param.is_meta
     for buffer in model.buffers():
         assert not buffer.is_meta
 
-    param_dict = dict(model.named_parameters())
-    for key in origin_param_dict.keys():
-        assert origin_param_dict[key].data.equal(param_dict[key].data)
+    conv1_weight_before_init = model.conv1.weight.clone()
+    ctx.lazy_init_parameters(model)
+    conv1_weight_after_init = model.conv1.weight.clone()
+
+    assert not torch.allclose(conv1_weight_after_init, conv1_weight_before_init)
 
 
 if __name__ == '__main__':
-    test_lazy_init()
+    test_lazy_init_with_meta()
+    test_lazy_init_without_meta()