[lazy] refactor lazy init (#3891)

* [lazy] remove old lazy init

* [lazy] refactor lazy init folder structure

* [lazy] fix lazy tensor deepcopy

* [test] update lazy init test
Author: Hongxin Liu (committed via GitHub)
Parent: 70c8cdecf4
Commit: dbb32692d2

@@ -0,0 +1,6 @@
from .lazy_init import LazyInitContext, LazyTensor

__all__ = [
    'LazyInitContext',
    'LazyTensor',
]

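The new package re-exports the public entry points, so downstream code switches import paths. A minimal sketch of the change (the old path shown for comparison is the one deleted later in this diff):

# New import path introduced by this PR
from colossalai.lazy import LazyInitContext, LazyTensor

# Old path, removed by this PR:
# from colossalai.utils.model.experimental import LazyInitContext, LazyTensor
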
@@ -350,7 +350,14 @@ class LazyTensor(torch.Tensor):
                 copied.requires_grad_()
             return copied

-        target = LazyTensor(factory_fn, meta_data=self._meta_data)
+        if self._materialized_data is not None:
+            # self is early materialized
+            copied = self._materialized_data.detach().clone()
+            if self.requires_grad:
+                copied.requires_grad_()
+            target = LazyTensor(lambda: None, concrete_data=copied)
+        else:
+            target = LazyTensor(factory_fn, meta_data=self._meta_data)
         memo[id(self)] = target
         return target

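The added branch makes a deep copy of an early-materialized LazyTensor clone its concrete data instead of re-running the factory. A small usage sketch of the scenario being fixed, modeled on the updated check_lazy_init test below (nn.Linear stands in for any model built under the context):

from copy import deepcopy

import torch.nn as nn

from colossalai.lazy import LazyInitContext

ctx = LazyInitContext()
with ctx:
    model = nn.Linear(8, 8)

# Deep copy while parameters are still lazy (or already early materialized);
# both copies must materialize to equal, independent weights.
copied = deepcopy(model)
model = ctx.materialize(model)
copied = ctx.materialize(copied)
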
@@ -1,242 +0,0 @@
#!/usr/bin/env python
# coding: utf-8

import inspect
import types
from typing import Callable, List

import torch
import torch.nn as nn

from colossalai.tensor import ColoParameter, ColoTensor
from colossalai.utils.model.utils import substitute_init_recursively


class LazyInitContext():
    """
    A context to allow for lazy weight initialization of PyTorch modules. It intercepts the tensor
    initialization functions for lazy initialization.

    Note:
        This API is only experimental and subject to future changes.

    Usage:
        with LazyInitContext() as ctx:
            model = nn.Linear(10, 10)
            model.weight.zero_()

        # make sure the weight is a meta tensor
        assert model.weight.is_meta

        # initialize weights
        ctx.lazy_init_parameters(model)

        # make sure the weight is not a meta tensor
        # and initialized correctly
        assert not model.weight.is_meta and torch.all(model.weight == 0)

    Args:
        to_meta (bool): optional, whether to initialize the model with meta tensors, default is True. This
            argument exists for now because some corner cases such as self.weight = torch.zeros(...) cannot be captured yet.
        extra_torch_tensor_func (List[str]): extra torch tensor functions related
            to value setting, such as `zero_` and `triu_`. `zero_` is pre-added by default.
    """

    tensor_set_value_func = ['zero_', 'fill_']

    def __init__(self, to_meta: bool = True, extra_torch_tensor_func: List[str] = None):
        # TODO: hijack the torch constructor functions as well
        self._to_meta = to_meta
        self._intercepted_nn_init_func_cache = {}
        self._nn_init_methods = self._get_nn_init_methods()
        self._torch_mod_cls = torch.nn.modules.module.Module

        if extra_torch_tensor_func:
            # use tuple to remove duplicates
            self._torch_tensor_funcs = tuple(self.tensor_set_value_func + extra_torch_tensor_func)
        else:
            self._torch_tensor_funcs = self.tensor_set_value_func

    @property
    def to_meta(self):
        return self._to_meta

    def _cache_init_func(self, func):
        """
        This method wraps the ``torch.nn.init`` method and torch tensor value-setting functions
        so that the function call is cached instead of being executed.
        """

        def wrapped_init_func(tensor, *args, **kwargs):
            if tensor not in self._intercepted_nn_init_func_cache:
                self._intercepted_nn_init_func_cache[tensor] = []
            self._intercepted_nn_init_func_cache[tensor].append((func, args, kwargs))

        return wrapped_init_func

    def _get_nn_init_methods(self):
        """
        This method looks for all available functions in the ``torch.nn.init``
        module.
        """
        nn_init_method_names = dir(torch.nn.init)
        nn_init_methods = []

        # look for all methods in ``torch.nn.init`` module
        for name in nn_init_method_names:
            nn_init_methods.append((name, getattr(torch.nn.init, name)))

        def _is_init_method(item):
            name, func = item
            if (not isinstance(func, types.FunctionType) or name.startswith('_') or not name.endswith('_')):
                return False
            else:
                return True

        # remove methods which are not init functions
        nn_init_methods = list(filter(_is_init_method, nn_init_methods))
        return nn_init_methods

    def _wrap_module_init(self, func):
        """
        This method wraps the calls to the `__init__` of ``torch.nn.Module`` and replaces
        the argument device with value 'meta' so that all modules are created as meta tensors.
        """
        has_device = 'device' in inspect.signature(func).parameters

        def layer_lazy_init(module, *args, **kwargs):
            # if this module contains device argument
            # we set it to meta to initialize as meta backend
            if has_device:
                kwargs['device'] = 'meta'
            func(module, *args, **kwargs)

            # if device is not found, we initialize it and convert to meta
            if not has_device:
                module.to('meta')

        return layer_lazy_init

    def _get_tmp_origin_func_ref(self, name):
        """
        Generate a function name for consistency during caching and retrieving.
        """
        return f'_orig_{name}'

    def _patch_nn_init_funcs(self):
        # patch nn.init functions
        for name, func in self._nn_init_methods:
            setattr(torch.nn.init, name, self._cache_init_func(func))

    def _unpatch_nn_init_funcs(self):
        # unpatch nn.init functions
        for name, func in self._nn_init_methods:
            setattr(torch.nn.init, name, func)

    def _patch_submodule_init(self):
        # patch classes __init__ methods
        def _activate_wrap_init(cls):
            cls.__orig_init__ = cls.__init__
            cls.__init__ = self._wrap_module_init(cls.__init__)

        substitute_init_recursively(self._torch_mod_cls, _activate_wrap_init, set())

    def _unpatch_submodule_init(self):

        def _recover_orig_init(cls):
            cls.__init__ = cls.__orig_init__

        substitute_init_recursively(self._torch_mod_cls, _recover_orig_init, set())

    def _patch_torch_tensor_funcs(self):
        # patch tensor value-setting functions
        for func_name in self._torch_tensor_funcs:
            origin_func_name = self._get_tmp_origin_func_ref(func_name)
            origin_func = getattr(torch.Tensor, func_name)
            setattr(torch.Tensor, origin_func_name, origin_func)
            setattr(torch.Tensor, func_name, self._cache_init_func(origin_func))

    def _unpatch_torch_tensor_funcs(self):
        for func_name in self._torch_tensor_funcs:
            origin_func_name = self._get_tmp_origin_func_ref(func_name)
            origin_func = getattr(torch.Tensor, origin_func_name)
            setattr(torch.Tensor, func_name, origin_func)

    def __enter__(self):
        self._patch_torch_tensor_funcs()
        self._patch_nn_init_funcs()

        if self._to_meta:
            self._patch_submodule_init()
        return self

    def __exit__(self, *args, **kwargs):
        if self._to_meta:
            self._unpatch_submodule_init()
        self._unpatch_nn_init_funcs()
        self._unpatch_torch_tensor_funcs()

    def lazy_init_parameters(self, model: torch.nn.Module, device='cpu'):
        """
        Initialize the weights of the meta-tensor model.

        Args:
            model (`torch.nn.Module`): the model instantiated under the context.
            device (str): the device on which weights are initialized
        """

        def _init_recursively(module: nn.Module):
            # recursively initialize the module
            for mod in module.children():
                _init_recursively(mod)

            # initialize and shard tensors directly attached to the current module
            for name, param in module.named_parameters(recurse=False):
                _init_and_shard(module, name, param)

            for name, buf in module.named_buffers(recurse=False):
                _init_and_shard(module, name, buf)

        @torch.no_grad()
        def _init_and_shard(module, name, tensor):
            # check whether the tensor is a buffer or parameter
            is_param = isinstance(tensor, nn.parameter.Parameter)

            # get sharding spec
            dist_spec = getattr(tensor, 'dist_spec', None)
            pg = getattr(tensor, 'pg', None)
            comp_spec = getattr(tensor, 'comp_spec', None)

            # convert the tensor from meta to materialized one
            if tensor.is_meta:
                materialized_tensor = torch.empty_like(tensor, device=device)
                # if this tensor is a meta tensor, it must have an init function
                assert tensor in self._intercepted_nn_init_func_cache
            else:
                materialized_tensor = tensor

            # apply init function
            if tensor in self._intercepted_nn_init_func_cache:
                init_func, args, kwargs = self._intercepted_nn_init_func_cache[tensor][-1]
                init_func(materialized_tensor, *args, **kwargs)

            # convert it to ColoTensor or ColoParameter
            if is_param:
                tensor = ColoParameter.from_torch_tensor(materialized_tensor, requires_grad=tensor.requires_grad)
            else:
                tensor = ColoTensor.from_torch_tensor(materialized_tensor)

            # override the original tensor
            with torch.no_grad():
                setattr(module, name, tensor)

            # apply sharding
            if dist_spec:
                tensor.process_group = pg
                tensor.set_tensor_spec(dist_spec, comp_spec)

        _init_recursively(model)

        return model

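The deleted context implemented laziness by monkey-patching: every public in-place initializer in torch.nn.init (plus tensor setters such as zero_) was replaced by a wrapper that caches the call per tensor (_cache_init_func), and the cached calls were replayed later in lazy_init_parameters once device and sharding decisions were known. A minimal self-contained sketch of that pattern, not the removed implementation itself; names like _cached_calls and _caching are illustrative:

import torch
import torch.nn as nn

# Cache of deferred init calls, keyed by tensor identity (the removed context
# keyed its dict by the tensor object itself, which also hashes by identity).
_cached_calls = {}


def _caching(orig_func):
    # Wrap an in-place initializer so the call is recorded, not executed.
    def wrapper(tensor, *args, **kwargs):
        _cached_calls.setdefault(id(tensor), []).append((orig_func, args, kwargs))
        return tensor

    return wrapper


# Patch a single nn.init function; the removed context patched every public
# in-place initializer in torch.nn.init in the same way.
_orig_kaiming_uniform_ = torch.nn.init.kaiming_uniform_
torch.nn.init.kaiming_uniform_ = _caching(_orig_kaiming_uniform_)
try:
    layer = nn.Linear(4, 4)    # weight storage is allocated, init is deferred
finally:
    torch.nn.init.kaiming_uniform_ = _orig_kaiming_uniform_

# Replay the cached call(s) later, e.g. once device/sharding is decided.
for func, args, kwargs in _cached_calls[id(layer.weight)]:
    func(layer.weight, *args, **kwargs)

The replacement in colossalai.lazy centres on LazyTensor instead, which carries its own factory_fn and materializes on demand (see the __deepcopy__ hunk above), so this global patching of torch.nn.init is no longer part of the public lazy-init path.
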
@@ -2,13 +2,14 @@ import itertools
 from collections import OrderedDict
 from contextlib import nullcontext
 from functools import partial
-from typing import Dict, Iterator, List, Optional, Union, Tuple, Set
+from typing import Dict, Iterator, List, Optional, Set, Tuple, Union

 import torch
 import torch.distributed as dist
 import torch.nn as nn

 from colossalai.checkpoint_io.utils import calculate_tensor_size
+from colossalai.lazy import LazyTensor
 from colossalai.logging import get_dist_logger
 from colossalai.nn.parallel.data_parallel import ColoDDP, _cast_float, free_storage
 from colossalai.tensor import ProcessGroup as ColoProcessGroup
@@ -16,7 +17,6 @@ from colossalai.tensor import ReplicaSpec
 from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.utils import get_current_device, is_ddp_ignored
-from colossalai.utils.model.experimental import LazyTensor

 from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager
 from .gemini_hook import GeminiZeROHook
@@ -96,34 +96,38 @@ class ZeroDDP(ColoDDP):
                 param_name = m_name + '.' + p_name if m_name else p_name
                 self.name2param[param_name] = p_var
         super().__init__(module, process_group=ColoProcessGroup())
-        self._non_persistent_buffers_set=self._get_non_persistent_buffers_set(module)
+        self._non_persistent_buffers_set = self._get_non_persistent_buffers_set(module)
         self._cast_buffers()

-    def _get_non_persistent_buffers_set(self, module, memo: Optional[Set[nn.Module]] = None, prefix: str = '', remove_duplicate: bool = True):
-        r"""
-        Args:
-            memo: a memo to store the set of modules already added to the result
-            prefix: a prefix that will be added to the name of the module
-            remove_duplicate: whether to remove the duplicated module instances in the result
-                or not
-        """
-        if memo is None:
-            memo = set()
-        self_non_persistent_set = set()
-        if module not in memo:
-            if remove_duplicate:
-                memo.add(module)
-            self_non_persistent_set = set(map(lambda key: prefix + ('.' if prefix else '') + key, module._non_persistent_buffers_set))
-        for name, sub_module in module._modules.items():
-            if sub_module is None:
-                continue
-            submodule_prefix = prefix + ('.' if prefix else '') + name
-            child_non_persistent_set = self._get_non_persistent_buffers_set(sub_module, memo, submodule_prefix, remove_duplicate)
-            self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set)
-        return self_non_persistent_set
+    def _get_non_persistent_buffers_set(self,
+                                        module,
+                                        memo: Optional[Set[nn.Module]] = None,
+                                        prefix: str = '',
+                                        remove_duplicate: bool = True):
+        r"""
+        Args:
+            memo: a memo to store the set of modules already added to the result
+            prefix: a prefix that will be added to the name of the module
+            remove_duplicate: whether to remove the duplicated module instances in the result
+                or not
+        """
+        if memo is None:
+            memo = set()
+        self_non_persistent_set = set()
+        if module not in memo:
+            if remove_duplicate:
+                memo.add(module)
+            self_non_persistent_set = set(
+                map(lambda key: prefix + ('.' if prefix else '') + key, module._non_persistent_buffers_set))
+        for name, sub_module in module._modules.items():
+            if sub_module is None:
+                continue
+            submodule_prefix = prefix + ('.' if prefix else '') + name
+            child_non_persistent_set = self._get_non_persistent_buffers_set(sub_module, memo, submodule_prefix,
+                                                                            remove_duplicate)
+            self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set)
+        return self_non_persistent_set

     def _post_forward(self):
         """This function is only triggered for inference.

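For context, a small sketch (hypothetical Block/Net modules) of why the prefixed set is needed: PyTorch records non-persistent buffer names only per module and excludes them from state_dict, so ZeroDDP rebuilds the fully qualified names with the recursive helper above.

import torch
import torch.nn as nn


class Block(nn.Module):

    def __init__(self):
        super().__init__()
        # Non-persistent buffers are excluded from state_dict.
        self.register_buffer('cache', torch.zeros(4), persistent=False)


class Net(nn.Module):

    def __init__(self):
        super().__init__()
        self.block = Block()


net = Net()
# Per-module bookkeeping has no prefix ...
assert net.block._non_persistent_buffers_set == {'cache'}
# ... and the buffer never appears in state_dict, hence the qualified
# name 'block.cache' must be collected separately.
assert 'block.cache' not in net.state_dict()
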
@@ -8,10 +8,10 @@ import colossalai
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin
 from colossalai.fx import is_compatible_with_meta
+from colossalai.lazy.lazy_init import LazyInitContext
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.tensor.colo_parameter import ColoParameter
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
-from colossalai.utils.model.experimental import LazyInitContext
 from colossalai.zero import ColoInitContext

 from tests.kit.model_zoo import model_zoo

@@ -1,12 +1,13 @@
 import random
+from copy import deepcopy
 from typing import Any, Callable, Optional, Tuple

 import numpy as np
 import torch
 from packaging import version

+from colossalai.lazy.lazy_init import LazyInitContext, LazyTensor, _MyTensor
 from colossalai.tensor.d_tensor.layout_converter import to_global
-from colossalai.utils.model.experimental import LazyInitContext, LazyTensor, _MyTensor
 from tests.kit.model_zoo.registry import ModelAttribute

 SUPPORT_LAZY = version.parse(torch.__version__) >= version.parse('1.12.0')
@@ -31,6 +32,9 @@ def assert_model_equal(m1: torch.nn.Module, m2: torch.nn.Module) -> None:
         assert n1 == n2
         assert torch.equal(t1, t2), f'{n1} {t1} vs {t2}'

+    for p1, p2 in zip(m1.parameters(), m2.parameters()):
+        assert p1.requires_grad == p2.requires_grad
+

 def assert_forward_equal(m1: torch.nn.Module, m2: torch.nn.Module, data_gen_fn: Callable[[], dict],
                          output_transform_fn: Callable[[Any], dict]) -> None:
@@ -65,10 +69,14 @@ def check_lazy_init(entry: TestingEntry, seed: int = 42, verbose: bool = False,
     ctx = LazyInitContext()
     with ctx:
         deferred_model = model_fn()
+    copied_deferred_model = deepcopy(deferred_model)
     deferred_model = ctx.materialize(deferred_model, verbose=verbose)
+    copied_deferred_model = ctx.materialize(copied_deferred_model, verbose=verbose)
     assert_model_equal(model, deferred_model)
+    assert_model_equal(deferred_model, copied_deferred_model)
     if check_forward:
         assert_forward_equal(model, deferred_model, data_gen_fn, output_transform_fn)
+        assert_forward_equal(deferred_model, copied_deferred_model, data_gen_fn, output_transform_fn)

     if verbose:
         print(f'{model.__class__.__name__} pass')

@@ -12,7 +12,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.utils.common import print_rank_0

 try:
-    from colossalai.utils.model.experimental import LazyInitContext, LazyTensor, _MyTensor
+    from colossalai.lazy.lazy_init import LazyInitContext, LazyTensor, _MyTensor
 except:
     pass
 from lazy_init_utils import SUPPORT_LAZY, assert_dist_model_equal, set_seed

@ -1,51 +0,0 @@
import torch
from colossalai.utils.model.lazy_init_context import LazyInitContext
from torchvision.models import resnet34
import random
import numpy as np
MANUAL_SEED = 0
random.seed(MANUAL_SEED)
np.random.seed(MANUAL_SEED)
torch.manual_seed(MANUAL_SEED)
def test_lazy_init_with_meta():
ctx = LazyInitContext(to_meta=True)
with ctx:
model = resnet34(num_classes=10)
for param in model.parameters():
assert param.is_meta
for buffer in model.buffers():
assert buffer.is_meta
ctx.lazy_init_parameters(model)
for name, param in model.named_parameters():
assert not param.is_meta, name
for buffer in model.buffers():
assert not buffer.is_meta
def test_lazy_init_without_meta():
ctx = LazyInitContext(to_meta=False)
with ctx:
model = resnet34(num_classes=10)
for param in model.parameters():
assert not param.is_meta
for buffer in model.buffers():
assert not buffer.is_meta
conv1_weight_before_init = model.conv1.weight.clone()
ctx.lazy_init_parameters(model)
conv1_weight_after_init = model.conv1.weight.clone()
assert not torch.allclose(conv1_weight_after_init, conv1_weight_before_init)
if __name__ == '__main__':
test_lazy_init_with_meta()
test_lazy_init_without_meta()