mirror of https://github.com/hpcaitech/ColossalAI
[utils] integrated colotensor with lazy init context (#1324)
* [utils] integrated colotensor with lazy init context * polish code * polish code * polish codepull/1326/head
parent
659a740738
commit
250be4d31e
|
@ -2,13 +2,13 @@
|
|||
# coding: utf-8
|
||||
|
||||
import torch
|
||||
from colossalai.tensor import ColoParameter
|
||||
import torch.nn as nn
|
||||
from colossalai.tensor import ColoParameter, ColoTensor
|
||||
|
||||
import types
|
||||
import inspect
|
||||
import typing
|
||||
from typing import List, Callable
|
||||
from colossalai.utils.model.utils import substitute_init_recursively
|
||||
import copy
|
||||
|
||||
|
||||
class LazyInitContext():
|
||||
|
@ -18,8 +18,7 @@ class LazyInitContext():
|
|||
|
||||
Note:
|
||||
This API is only experimental and subject to future changes.
|
||||
It should be integrated with meta tensor initialization in the future.
|
||||
|
||||
|
||||
Usage:
|
||||
with LazyInitContext() as ctx:
|
||||
model = nn.Linear(10, 10)
|
||||
|
@ -36,14 +35,17 @@ class LazyInitContext():
|
|||
assert not model.weight.is_meta and torch.all(model.weight == 0)
|
||||
|
||||
Args:
|
||||
to_meta (bool): optional, whether to initialize the model with meta tensors, default is False.
|
||||
extra_torch_tensor_func (List[str]): extra torch tensor functions related
|
||||
to value setting, such as `zero_` and `triu_`. `zero_` is pre-added by default.
|
||||
"""
|
||||
|
||||
tensor_set_value_func = ['zero_']
|
||||
tensor_set_value_func = ['zero_', 'fill_']
|
||||
|
||||
def __init__(self, extra_torch_tensor_func: List[str] = None):
|
||||
self._intercepted_init_func_cache = []
|
||||
def __init__(self, to_meta: bool = False, extra_torch_tensor_func: List[str] = None):
|
||||
# TODO: hijack the torch constructor functions as well
|
||||
self._to_meta = to_meta
|
||||
self._intercepted_nn_init_func_cache = {}
|
||||
self._nn_init_methods = self._get_nn_init_methods()
|
||||
self._torch_mod_cls = torch.nn.modules.module.Module
|
||||
|
||||
|
@ -53,14 +55,20 @@ class LazyInitContext():
|
|||
else:
|
||||
self._torch_tensor_funcs = self.tensor_set_value_func
|
||||
|
||||
def _cache_func(self, func):
|
||||
@property
|
||||
def to_meta(self):
|
||||
return self._to_meta
|
||||
|
||||
def _cache_init_func(self, func):
|
||||
"""
|
||||
This method wraps the ``torch.nn.init`` method so that the function call
|
||||
is cached instead of being executed.
|
||||
This method wraps the ``torch.nn.init`` method and torch tensor value-setting functions
|
||||
so that the function call is cached instead of being executed.
|
||||
"""
|
||||
|
||||
def wrapped_init_func(*args, **kwargs):
|
||||
self._intercepted_init_func_cache.append(dict(func=func, args=args, kwargs=kwargs))
|
||||
def wrapped_init_func(tensor, *args, **kwargs):
|
||||
if tensor not in self._intercepted_nn_init_func_cache:
|
||||
self._intercepted_nn_init_func_cache[tensor] = []
|
||||
self._intercepted_nn_init_func_cache[tensor].append((func, args, kwargs))
|
||||
|
||||
return wrapped_init_func
|
||||
|
||||
|
@ -76,17 +84,10 @@ class LazyInitContext():
|
|||
for name in nn_init_method_names:
|
||||
nn_init_methods.append((name, getattr(torch.nn.init, name)))
|
||||
|
||||
def _has_tensor_in_arg(func):
|
||||
hints = typing.get_type_hints(func)
|
||||
for k, v in hints.items():
|
||||
if v is torch.Tensor:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _is_init_method(item):
|
||||
name, func = item
|
||||
if (not isinstance(func, types.FunctionType) or name.startswith('_') or not name.endswith('_')
|
||||
or not _has_tensor_in_arg(func)):
|
||||
|
||||
if (not isinstance(func, types.FunctionType) or name.startswith('_') or not name.endswith('_')):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
@ -103,11 +104,13 @@ class LazyInitContext():
|
|||
has_device = 'device' in inspect.signature(func).parameters
|
||||
|
||||
def layer_lazy_init(module, *args, **kwargs):
|
||||
self._intercepted_init_func_cache.append(
|
||||
dict(func=func, module=module, args=args, kwargs=copy.deepcopy(kwargs)))
|
||||
# if this module contains device argument
|
||||
# we set it to meta to initialize as meta backend
|
||||
if has_device:
|
||||
kwargs['device'] = 'meta'
|
||||
func(module, *args, **kwargs)
|
||||
|
||||
# if device is not found, we intialize it and convert to meta
|
||||
if not has_device:
|
||||
module.to('meta')
|
||||
|
||||
|
@ -122,7 +125,7 @@ class LazyInitContext():
|
|||
def _patch_nn_init_funcs(self):
|
||||
# patch nn.init functions
|
||||
for name, func in self._nn_init_methods:
|
||||
setattr(torch.nn.init, name, self._cache_func(func))
|
||||
setattr(torch.nn.init, name, self._cache_init_func(func))
|
||||
|
||||
def _unpatch_nn_init_funcs(self):
|
||||
# unpatch nn.init functions
|
||||
|
@ -150,7 +153,7 @@ class LazyInitContext():
|
|||
origin_func_name = self._get_tmp_origin_func_ref(func_name)
|
||||
origin_func = getattr(torch.Tensor, func_name)
|
||||
setattr(torch.Tensor, origin_func_name, origin_func)
|
||||
setattr(torch.Tensor, func_name, self._cache_func(origin_func))
|
||||
setattr(torch.Tensor, func_name, self._cache_init_func(origin_func))
|
||||
|
||||
def _unpatch_torch_tensor_funcs(self):
|
||||
for func_name in self._torch_tensor_funcs:
|
||||
|
@ -159,17 +162,18 @@ class LazyInitContext():
|
|||
setattr(torch.Tensor, func_name, origin_func)
|
||||
|
||||
def __enter__(self):
|
||||
self._patch_submodule_init()
|
||||
self._patch_torch_tensor_funcs()
|
||||
self._patch_nn_init_funcs()
|
||||
|
||||
if self._to_meta:
|
||||
self._patch_submodule_init()
|
||||
return self
|
||||
|
||||
def __exit__(self, *args, **kwargs):
|
||||
self._unpatch_submodule_init()
|
||||
# build model_rebuild_dict in reverse order to make sure get correct init func for inherited class.
|
||||
self.module_rebuild_dict = {}
|
||||
self._intercepted_init_func_cache.reverse()
|
||||
for cache in self._intercepted_init_func_cache:
|
||||
self.module_rebuild_dict[cache['module']] = (cache['func'], cache['args'], cache['kwargs'])
|
||||
self._intercepted_init_func_cache.reverse()
|
||||
if self._to_meta:
|
||||
self._unpatch_submodule_init()
|
||||
self._unpatch_nn_init_funcs()
|
||||
self._unpatch_torch_tensor_funcs()
|
||||
|
||||
def lazy_init_parameters(self, model: torch.nn.Module, device='cpu', call_back: Callable = None):
|
||||
"""
|
||||
|
@ -178,80 +182,56 @@ class LazyInitContext():
|
|||
Args:
|
||||
model (`torch.nn.Module`): the model instantiated under the context.
|
||||
device (str): the device on which weights are initialized
|
||||
|
||||
|
||||
"""
|
||||
# build param mapping
|
||||
param_id_to_name = dict()
|
||||
for name, param in model.named_parameters():
|
||||
param_id_to_name[id(param)] = name
|
||||
for name, buffer in model.named_buffers():
|
||||
param_id_to_name[id(buffer)] = name
|
||||
|
||||
assert model in self.module_rebuild_dict, 'We only support rebuild modules which intercepted during initializing by us.'
|
||||
def _init_recursively(module: nn.Module):
|
||||
# recursively initialize the module
|
||||
for mod in module.children():
|
||||
_init_recursively(mod)
|
||||
|
||||
def _process_arg(arg):
|
||||
"""
|
||||
Process args recursively. If arg is a torch.nn.Module instance in module_rebuild_dict,
|
||||
we need to rebuild it with real parameters. If arg is a tuple or list, we will process
|
||||
the element of arg with this function again.
|
||||
"""
|
||||
if torch.is_tensor(arg):
|
||||
tensor_id = id(arg)
|
||||
if tensor_id in param_id_to_name:
|
||||
arg = _replace_meta_param_with_real_param(arg)
|
||||
# initialize and shard tensors directly attached to the current module
|
||||
for name, param in module.named_parameters(recurse=False):
|
||||
_init_and_shard(module, name, param)
|
||||
|
||||
elif isinstance(arg, torch.nn.Module):
|
||||
if arg in self.module_rebuild_dict:
|
||||
arg = self.lazy_init_parameters(model=arg, device=device, call_back=call_back)
|
||||
for name, buf in module.named_buffers(recurse=False):
|
||||
_init_and_shard(module, name, buf)
|
||||
|
||||
elif isinstance(arg, (tuple, list)):
|
||||
rst_list = []
|
||||
for element in arg:
|
||||
processed_element = _process_arg(element)
|
||||
rst_list.append(processed_element)
|
||||
arg = rst_list
|
||||
return arg
|
||||
@torch.no_grad()
|
||||
def _init_and_shard(module, name, tensor):
|
||||
# check whether the tensor is a buffer or parameter
|
||||
is_param = isinstance(tensor, nn.parameter.Parameter)
|
||||
|
||||
def _replace_meta_param_with_real_param(meta_param):
|
||||
if meta_param.device != 'meta':
|
||||
return meta_param
|
||||
tensor_id = id(meta_param)
|
||||
param_full_name = param_id_to_name[tensor_id]
|
||||
real_param = torch.empty_like(meta_param, dtype=meta_param.dtype, device=device)
|
||||
real_param = ColoParameter(real_param, requires_grad=meta_param.requires_grad)
|
||||
# get sharding spec
|
||||
dist_spec = getattr(tensor, 'dist_spec', None)
|
||||
pg = getattr(tensor, 'pg', None)
|
||||
|
||||
if '.' in param_full_name:
|
||||
submodule_name, param_name = param_full_name.rsplit('.', 1)
|
||||
submodule = model.get_submodule(submodule_name)
|
||||
# convert the tensor from meta to materialized one
|
||||
if tensor.is_meta:
|
||||
materialized_tensor = torch.empty_like(tensor, device=device)
|
||||
# if this tensor is a meta tensor, it must have an init function
|
||||
assert tensor in self._intercepted_nn_init_func_cache
|
||||
tensor = materialized_tensor
|
||||
|
||||
# apply init function
|
||||
if tensor in self._intercepted_nn_init_func_cache:
|
||||
init_func, args, kwargs = self._intercepted_nn_init_func_cache[tensor][-1]
|
||||
init_func(tensor, *args, **kwargs)
|
||||
|
||||
# convert it to ColoTensor or ColoParameter
|
||||
if is_param:
|
||||
tensor = ColoParameter.from_torch_tensor(tensor, requires_grad=tensor.requires_grad)
|
||||
else:
|
||||
submodule = model
|
||||
param_name = param_full_name
|
||||
setattr(submodule, param_name, real_param)
|
||||
tensor = ColoTensor.from_torch_tensor(tensor)
|
||||
|
||||
# execute call_back function on the materailized tensor
|
||||
# this can where sharding comes in
|
||||
if call_back:
|
||||
call_back(real_param)
|
||||
return real_param
|
||||
# apply sharding
|
||||
if dist_spec:
|
||||
tensor = tensor.redistribute(dist_spec=dist_spec, pg=pg)
|
||||
|
||||
func, args, kwargs = self.module_rebuild_dict[model]
|
||||
args = list(args)
|
||||
# override the original tensor
|
||||
with torch.no_grad():
|
||||
setattr(module, name, tensor)
|
||||
|
||||
# check args for parameter replacement
|
||||
for idx, arg in enumerate(args):
|
||||
arg = _process_arg(arg)
|
||||
args[idx] = arg
|
||||
|
||||
# check kwargs for parameter replacement
|
||||
for arg_name, arg in kwargs.items():
|
||||
if arg_name == 'device':
|
||||
arg = device
|
||||
else:
|
||||
arg = _process_arg(arg)
|
||||
kwargs[arg_name] = arg
|
||||
|
||||
# build user specified model
|
||||
with torch.no_grad():
|
||||
func(model, *args, **kwargs)
|
||||
_init_recursively(model)
|
||||
|
||||
return model
|
||||
|
|
|
@ -10,27 +10,42 @@ np.random.seed(MANUAL_SEED)
|
|||
torch.manual_seed(MANUAL_SEED)
|
||||
|
||||
|
||||
def test_lazy_init():
|
||||
cpu_rng_state = torch.get_rng_state()
|
||||
origin_model = resnet34(num_classes=10)
|
||||
origin_param_dict = dict(origin_model.named_parameters())
|
||||
torch.set_rng_state(cpu_rng_state)
|
||||
ctx = LazyInitContext()
|
||||
def test_lazy_init_with_meta():
|
||||
ctx = LazyInitContext(to_meta=True)
|
||||
with ctx:
|
||||
model = resnet34(num_classes=10)
|
||||
|
||||
for param in model.parameters():
|
||||
assert param.is_meta
|
||||
for buffer in model.buffers():
|
||||
assert buffer.is_meta
|
||||
|
||||
ctx.lazy_init_parameters(model)
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
assert not param.is_meta, name
|
||||
|
||||
for buffer in model.buffers():
|
||||
assert not buffer.is_meta
|
||||
|
||||
|
||||
def test_lazy_init_without_meta():
|
||||
ctx = LazyInitContext(to_meta=False)
|
||||
with ctx:
|
||||
model = resnet34(num_classes=10)
|
||||
|
||||
for param in model.parameters():
|
||||
assert not param.is_meta
|
||||
for buffer in model.buffers():
|
||||
assert not buffer.is_meta
|
||||
param_dict = dict(model.named_parameters())
|
||||
for key in origin_param_dict.keys():
|
||||
assert origin_param_dict[key].data.equal(param_dict[key].data)
|
||||
|
||||
conv1_weight_before_init = model.conv1.weight.clone()
|
||||
ctx.lazy_init_parameters(model)
|
||||
conv1_weight_after_init = model.conv1.weight.clone()
|
||||
|
||||
assert not torch.allclose(conv1_weight_after_init, conv1_weight_before_init)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_lazy_init()
|
||||
test_lazy_init_with_meta()
|
||||
test_lazy_init_without_meta()
|
||||
|
|
Loading…
Reference in New Issue