mirror of https://github.com/hpcaitech/ColossalAI
[context]support arbitary module materialization. (#1193)
* [CLI] add CLI launcher
* Revert "[CLI] add CLI launcher"
This reverts commit df7e6506d4
.
* [context]support arbitary module materialization.
* [test]add numerical check for lazy init context.
pull/1197/head
parent
a444633d13
commit
63d2a93878
|
@ -8,6 +8,7 @@ import inspect
|
|||
import typing
|
||||
from typing import List, Callable
|
||||
from colossalai.utils.model.utils import substitute_init_recursively
|
||||
import copy
|
||||
|
||||
|
||||
class LazyInitContext():
|
||||
|
@ -102,7 +103,8 @@ class LazyInitContext():
|
|||
has_device = 'device' in inspect.signature(func).parameters
|
||||
|
||||
def layer_lazy_init(module, *args, **kwargs):
|
||||
self._intercepted_init_func_cache.append(dict(func=func, module=module, args=args, kwargs=kwargs))
|
||||
self._intercepted_init_func_cache.append(
|
||||
dict(func=func, module=module, args=args, kwargs=copy.deepcopy(kwargs)))
|
||||
if has_device:
|
||||
kwargs['device'] = 'meta'
|
||||
func(module, *args, **kwargs)
|
||||
|
@ -162,6 +164,12 @@ class LazyInitContext():
|
|||
|
||||
def __exit__(self, *args, **kwargs):
|
||||
self._unpatch_submodule_init()
|
||||
# build model_rebuild_dict in reverse order to make sure get correct init func for inherited class.
|
||||
self.module_rebuild_dict = {}
|
||||
self._intercepted_init_func_cache.reverse()
|
||||
for cache in self._intercepted_init_func_cache:
|
||||
self.module_rebuild_dict[cache['module']] = (cache['func'], cache['args'], cache['kwargs'])
|
||||
self._intercepted_init_func_cache.reverse()
|
||||
|
||||
def lazy_init_parameters(self, model: torch.nn.Module, device='cpu', call_back: Callable = None):
|
||||
"""
|
||||
|
@ -179,7 +187,34 @@ class LazyInitContext():
|
|||
for name, buffer in model.named_buffers():
|
||||
param_id_to_name[id(buffer)] = name
|
||||
|
||||
assert model in self.module_rebuild_dict, 'We only support rebuild modules which intercepted during initializing by us.'
|
||||
|
||||
def _process_arg(arg):
|
||||
"""
|
||||
Process args recursively. If arg is a torch.nn.Module instance in module_rebuild_dict,
|
||||
we need to rebuild it with real parameters. If arg is a tuple or list, we will process
|
||||
the element of arg with this function again.
|
||||
"""
|
||||
if torch.is_tensor(arg):
|
||||
tensor_id = id(arg)
|
||||
if tensor_id in param_id_to_name:
|
||||
arg = _replace_meta_param_with_real_param(arg)
|
||||
|
||||
elif isinstance(arg, torch.nn.Module):
|
||||
if arg in self.module_rebuild_dict:
|
||||
arg = self.lazy_init_parameters(model=arg, device=device, call_back=call_back)
|
||||
|
||||
elif isinstance(arg, (tuple, list)):
|
||||
rst_list = []
|
||||
for element in arg:
|
||||
processed_element = _process_arg(element)
|
||||
rst_list.append(processed_element)
|
||||
arg = rst_list
|
||||
return arg
|
||||
|
||||
def _replace_meta_param_with_real_param(meta_param):
|
||||
if meta_param.device != 'meta':
|
||||
return meta_param
|
||||
tensor_id = id(meta_param)
|
||||
param_full_name = param_id_to_name[tensor_id]
|
||||
real_param = torch.empty_like(meta_param, dtype=meta_param.dtype, device=device)
|
||||
|
@ -199,36 +234,24 @@ class LazyInitContext():
|
|||
call_back(real_param)
|
||||
return real_param
|
||||
|
||||
# build modules
|
||||
# visit the cache list in reverse order
|
||||
for index in range(len(self._intercepted_init_func_cache)):
|
||||
cache = self._intercepted_init_func_cache[len(self._intercepted_init_func_cache) - index - 1]
|
||||
func = cache['func']
|
||||
module = cache['module']
|
||||
args = list(cache['args'])
|
||||
kwargs = cache['kwargs']
|
||||
func, args, kwargs = self.module_rebuild_dict[model]
|
||||
args = list(args)
|
||||
|
||||
# check args for parameter replacement
|
||||
for idx, arg in enumerate(args):
|
||||
if torch.is_tensor(arg):
|
||||
tensor_id = id(arg)
|
||||
# check args for parameter replacement
|
||||
for idx, arg in enumerate(args):
|
||||
arg = _process_arg(arg)
|
||||
args[idx] = arg
|
||||
|
||||
if tensor_id not in param_id_to_name:
|
||||
continue
|
||||
else:
|
||||
arg = _replace_meta_param_with_real_param(arg)
|
||||
args[idx] = arg
|
||||
# check kwargs for parameter replacement
|
||||
for arg_name, arg in kwargs.items():
|
||||
if arg_name == 'device':
|
||||
arg = device
|
||||
else:
|
||||
arg = _process_arg(arg)
|
||||
kwargs[arg_name] = arg
|
||||
|
||||
# check kwargs for parameter replacement
|
||||
for arg_name, arg in enumerate(kwargs):
|
||||
if torch.is_tensor(arg):
|
||||
tensor_id = id(arg)
|
||||
# build user specified model
|
||||
with torch.no_grad():
|
||||
func(model, *args, **kwargs)
|
||||
|
||||
if tensor_id not in param_id_to_name:
|
||||
continue
|
||||
else:
|
||||
arg = _replace_meta_param_with_real_param(arg)
|
||||
kwargs[arg_name] = arg
|
||||
|
||||
with torch.no_grad():
|
||||
func(module, *args, **kwargs)
|
||||
return model
|
||||
|
|
|
@ -1,9 +1,20 @@
|
|||
import torch
|
||||
from colossalai.utils.model.lazy_init_context import LazyInitContext
|
||||
from torchvision.models import resnet34
|
||||
import random
|
||||
import numpy as np
|
||||
|
||||
MANUAL_SEED = 0
|
||||
random.seed(MANUAL_SEED)
|
||||
np.random.seed(MANUAL_SEED)
|
||||
torch.manual_seed(MANUAL_SEED)
|
||||
|
||||
|
||||
def test_lazy_init():
|
||||
cpu_rng_state = torch.get_rng_state()
|
||||
origin_model = resnet34(num_classes=10)
|
||||
origin_param_dict = dict(origin_model.named_parameters())
|
||||
torch.set_rng_state(cpu_rng_state)
|
||||
ctx = LazyInitContext()
|
||||
with ctx:
|
||||
model = resnet34(num_classes=10)
|
||||
|
@ -16,6 +27,9 @@ def test_lazy_init():
|
|||
assert not param.is_meta
|
||||
for buffer in model.buffers():
|
||||
assert not buffer.is_meta
|
||||
param_dict = dict(model.named_parameters())
|
||||
for key in origin_param_dict.keys():
|
||||
assert origin_param_dict[key].data.equal(param_dict[key].data)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
import torch
|
||||
from colossalai.utils.model.lazy_init_context import LazyInitContext
|
||||
from torchvision.models import resnet34
|
||||
import random
|
||||
import numpy as np
|
||||
|
||||
MANUAL_SEED = 0
|
||||
random.seed(MANUAL_SEED)
|
||||
np.random.seed(MANUAL_SEED)
|
||||
torch.manual_seed(MANUAL_SEED)
|
||||
|
||||
|
||||
class MLP(torch.nn.Module):
|
||||
|
||||
def __init__(self, dim: int = 4):
|
||||
super().__init__()
|
||||
intermediate_dim = dim * 4
|
||||
self.dense_1 = torch.nn.Linear(dim, intermediate_dim)
|
||||
self.activation = torch.nn.GELU()
|
||||
self.dense_2 = torch.nn.Linear(intermediate_dim, dim)
|
||||
self.dropout = torch.nn.Dropout(0.1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.dense_1(x)
|
||||
x = self.activation(x)
|
||||
x = self.dense_2(x)
|
||||
x = self.dropout(x)
|
||||
return x
|
||||
|
||||
|
||||
def test_lazy_init():
|
||||
cpu_rng_state = torch.get_rng_state()
|
||||
origin_model = MLP()
|
||||
origin_param_dict = dict(origin_model.named_parameters())
|
||||
torch.set_rng_state(cpu_rng_state)
|
||||
ctx = LazyInitContext()
|
||||
with ctx:
|
||||
model = MLP()
|
||||
for param in model.parameters():
|
||||
assert param.is_meta
|
||||
for buffer in model.buffers():
|
||||
assert buffer.is_meta
|
||||
for module in model.children():
|
||||
ctx.lazy_init_parameters(module)
|
||||
for param in module.parameters():
|
||||
assert not param.is_meta
|
||||
for buffer in module.buffers():
|
||||
assert not buffer.is_meta
|
||||
param_dict = dict(model.named_parameters())
|
||||
for key in origin_param_dict.keys():
|
||||
assert origin_param_dict[key].data.equal(param_dict[key].data)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_lazy_init()
|
Loading…
Reference in New Issue