diff --git a/colossalai/utils/model/experimental.py b/colossalai/utils/model/experimental.py index b8eb742f8..00cb532d9 100644 --- a/colossalai/utils/model/experimental.py +++ b/colossalai/utils/model/experimental.py @@ -1,17 +1,16 @@ -from typing import Callable, Optional, Union +from typing import Callable, List, Optional, Union import torch import torch.nn as nn from torch import Tensor from torch.utils._pytree import tree_map -from colossalai.fx.profiler import MetaTensor +from colossalai.fx.profiler.tensor import MetaTensor # reference: https://pytorch.org/cppdocs/notes/tensor_creation.html -_TorchFactoryMethod = [ +_NORMAL_FACTORY = [ "arange", "empty", - "eye", "full", "linspace", "logspace", @@ -24,17 +23,39 @@ _TorchFactoryMethod = [ "tensor", ] +# factory function that does not support meta tensor backend +_NO_META_FACTORY = [ + "eye", +] + _EARLY_MATERIALIZED_OPS = ['__getitem__', 'split'] +_LEGACY_TENSOR_CONSTRUCTOR = { + 'FloatTensor': torch.float, + 'DoubleTensor': torch.double, + 'HalfTensor': torch.half, + 'BFloat16Tensor': torch.bfloat16, + 'ByteTensor': torch.uint8, + 'CharTensor': torch.int8, + 'ShortTensor': torch.short, + 'IntTensor': torch.int, + 'LongTensor': torch.long, + 'BoolTensor': torch.bool, +} + class _MyTensor(Tensor): """This class is only for correctness verification. """ _pre_op_fn: Callable[['LazyTensor'], None] = lambda *args: None - def __new__(cls, func, *args, dtype=None, device=None, **kwargs) -> '_MyTensor': + def __new__(cls, func, *args, concrete_data=None, **kwargs) -> '_MyTensor': cls._pre_op_fn() - data = func(*args, dtype=dtype, device=device, **kwargs) + if concrete_data is not None: + # uniform api as LazyTensor + data = concrete_data + else: + data = func(*args, **kwargs) return Tensor._make_subclass(cls, data, require_grad=data.requires_grad) @classmethod @@ -66,11 +87,13 @@ class LazyTensor(torch.Tensor): >>> x.add_(1) # modifying origin tensor after cloning leads to wrong materialization >>> z = x.tolist() >>> x.zeros_() # modifying origin tensor after cloning tolist is not allowed - >>> x.data = torch.rand(2, 3) # directly set data of a lazy tensor is not allowed + >>> nn.utils.weight_norm(self.conv, name="weight", dim=2) # applying weight norm on a lazy tensor is not allowed + 2. Cases that ``LazyTensor`` becomes eager (early materialization). >>> b = a[:, 2:] # get a slice of a lazy tensor triggers early materialization >>> chunks = a.split(3) # this also triggers early materialization + >>> x.data = torch.rand(2, 3) # directly setting data of a lazy tensor triggers early materialization """ @@ -79,12 +102,16 @@ class LazyTensor(torch.Tensor): _pre_op_fn: Callable[['LazyTensor'], None] = lambda *args: None @staticmethod - def __new__(cls, func, *args, meta_data=None, **kwargs): - if meta_data is None: - device = kwargs.get('device', 'cpu') - elem = func(*args, **{**kwargs, 'device': 'meta'}) - meta_data = MetaTensor(elem, fake_device=device) - elem = meta_data._tensor + def __new__(cls, func, *args, meta_data=None, concrete_data=None, **kwargs): + if concrete_data is not None: + # some ops don't support meta backend and should have concrete data + elem = concrete_data + else: + if meta_data is None: + device = kwargs.get('device', 'cpu') + elem = func(*args, **{**kwargs, 'device': 'meta'}) + meta_data = MetaTensor(elem, fake_device=device) + elem = meta_data._tensor r = torch.Tensor._make_wrapper_subclass(cls, elem.size(), strides=elem.stride(), @@ -96,10 +123,10 @@ class LazyTensor(torch.Tensor): r._meta_data = meta_data return r - def __init__(self, func, *args, meta_data=None, **kwargs): + def __init__(self, func, *args, meta_data=None, concrete_data=None, **kwargs): self._factory_method = (func, args, kwargs) # (func, args, kwargs) self._op_buffer = [] # (func, args, kwargs, replace) - self._materialized_data: Optional[torch.Tensor] = None # materialized data + self._materialized_data: Optional[torch.Tensor] = concrete_data # materialized data def materialize(self) -> torch.Tensor: """Materialize the ``LazyTensor`` to ``torch.Tensor``. @@ -212,7 +239,7 @@ class LazyTensor(torch.Tensor): if isinstance(x, LazyTensor): if x._materialized_data is not None: # for early materialized tensor, use its materialized data directly - return x._materialized_data + return x._materialized_data.data t = x if is_inplace else x.clone() t._op_buffer.append((func, args, kwargs)) meta = x._meta_data.data @@ -232,13 +259,10 @@ class LazyTensor(torch.Tensor): return lazy_y elif type(y) is Tensor: # for early materialized tensor - with torch._C.DisableTorchFunction(): - meta = MetaTensor(y.new_empty(y.shape, dtype=y.dtype, device='meta'), fake_device=y.device) - lazy_y = LazyTensor(lambda: None, meta_data=meta) - lazy_y._materialized_data = y - return lazy_y + return LazyTensor(lambda: None, concrete_data=y) return y + cls._pre_op_fn() o = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)) if isinstance(o, (tuple, list)): return type(o)(wrap(y, i=i) for i, y in enumerate(o)) @@ -266,7 +290,10 @@ class LazyTensor(torch.Tensor): @data.setter def data(self, other: 'LazyTensor'): - raise NotImplementedError + if other is self: + return + # TODO(ver217): to avoid infinity recursion, do early materialization + self._materialized_data = other._materialize_data() def tolist(self) -> list: t = self.materialize() @@ -330,18 +357,61 @@ class LazyInitContext: return wrapper, target + def wrap_legacy_constructor(target, dtype): + # legacy constructor (e.g. torch.LongTensor()) + def wrapper(*args, **kwargs): + if len(args) == 1 and isinstance(args[0], torch.Tensor): + # (Tensor other) + return args[0] + elif len(args) == 1: + # (object data, *, torch.device device) + kwargs = {**kwargs, 'dtype': dtype} + replaced, orig = self.overrides['tensor'] + return replaced(*args, **kwargs) + elif _is_int_tuple(args): + # (tuple of ints size, *, torch.device device) + kwargs = {**kwargs, 'dtype': dtype} + replaced, orig = self.overrides['empty'] + return replaced(*args, **kwargs) + else: + raise TypeError( + f'new() received an invalid combination of arguments - got {tuple(type(x) for x in args)}, but expected one of:\n * (Tensor other)\n * (tuple of ints size, *, torch.device device)\n * (object data, *, torch.device device)' + ) + + return wrapper, target + + def wrap_no_meta_factory(target): + # factory functions which don't support meta tensor backend + def wrapper(*args, **kwargs): + tensor = target(*args, **kwargs) + return self.tensor_cls(lambda: None, concrete_data=tensor) + + return wrapper, target + self.overrides = { target: wrap_factory_method(getattr(torch, target)) - for target in _TorchFactoryMethod + for target in _NORMAL_FACTORY if callable(getattr(torch, target, None)) } self.overrides.update({ target + '_like': wrap_factory_like_method(getattr(torch, target), getattr(torch, target + '_like')) - for target in _TorchFactoryMethod + for target in _NORMAL_FACTORY if callable(getattr(torch, target + '_like', None)) }) + self.overrides.update({ + target: wrap_legacy_constructor(getattr(torch, target), dtype) + for target, dtype in _LEGACY_TENSOR_CONSTRUCTOR.items() + if callable(getattr(torch, target, None)) + }) + + self.overrides.update({ + target: wrap_no_meta_factory(getattr(torch, target)) + for target in _NO_META_FACTORY + if callable(getattr(torch, target, None)) + }) + for name, (wrapper, orig) in self.overrides.items(): setattr(torch, name, wrapper) @@ -363,34 +433,65 @@ class LazyInitContext: param_lazy_cnt = 0 buf_cnt = 0 buf_lazy_cnt = 0 + non_lazy_numel = 0 + + # do post cleaning to handle shared parameter + visited_lazy_tensors: List[LazyTensor] = [] + # handle shared module + visited_modules = set() @torch.no_grad() def init_recursively(module: nn.Module): - nonlocal param_cnt, param_lazy_cnt, buf_cnt, buf_lazy_cnt + nonlocal param_cnt, param_lazy_cnt, buf_cnt, buf_lazy_cnt, non_lazy_numel # recursively initialize the module for mod in module.children(): - init_recursively(mod) + if id(mod) not in visited_modules: + visited_modules.add(id(mod)) + init_recursively(mod) # initialize tensors directly attached to the current module for name, param in module.named_parameters(recurse=False): if verbose: param_cnt += 1 - if param._materialized_data is None: + if getattr(param, '_materialized_data', False) is None: + # if no _materialized_data attr, the tensor is not lazy param_lazy_cnt += 1 - setattr(module, name, param.materialize()) - param.clean() + else: + non_lazy_numel += param.numel() + if hasattr(param, 'materialize'): + # TODO(ver217): apex layers cannot be captured + visited_lazy_tensors.append(param) + setattr(module, name, param.materialize()) for name, buf in module.named_buffers(recurse=False): if verbose: buf_cnt += 1 - if buf._materialized_data is None: + if getattr(buf, "_materialized_data", False) is None: + # if no _materialized_data attr, the tensor is not lazy buf_lazy_cnt += 1 - setattr(module, name, buf.materialize()) - buf.clean() + else: + non_lazy_numel += buf.numel() + if hasattr(buf, 'materialize'): + # TODO(ver217): apex layers cannot be captured + visited_lazy_tensors.append(buf) + setattr(module, name, buf.materialize()) init_recursively(module) + for t in visited_lazy_tensors: + t.clean() + if verbose: print(f'Param lazy rate: {param_lazy_cnt}/{param_cnt}') print(f'Buffer lazy rate: {buf_lazy_cnt}/{buf_cnt}') + print(f'Non-lazy numel: {non_lazy_numel} ({non_lazy_numel/1024**2:.3f} M)') return module + + +def _is_int_tuple(args) -> bool: + if not isinstance(args, tuple): + return False + for x in args: + if not isinstance(x, int): + return False + return True diff --git a/tests/kit/model_zoo/torchrec/__init__.py b/tests/kit/model_zoo/torchrec/__init__.py new file mode 100644 index 000000000..43952e699 --- /dev/null +++ b/tests/kit/model_zoo/torchrec/__init__.py @@ -0,0 +1 @@ +from .torchrec import * diff --git a/tests/test_utils/test_lazy_init/test_models.py b/tests/test_utils/test_lazy_init/test_models.py new file mode 100644 index 000000000..9faddecba --- /dev/null +++ b/tests/test_utils/test_lazy_init/test_models.py @@ -0,0 +1,23 @@ +import pytest + +from tests.kit.model_zoo import model_zoo + +# FIXME(ver217): uncomment this line +# from utils import check_lazy_init + + +# FIXME(ver217): temporarily skip this test since torch 1.11 does not fully support meta tensor +@pytest.mark.skip +@pytest.mark.parametrize('subset', ['torchvision', 'diffusers', 'timm', 'transformers', 'torchaudio', 'deepfm', 'dlrm']) +def test_torchvision_models_lazy_init(subset): + sub_model_zoo = model_zoo.get_sub_registry(subset) + for name, entry in sub_model_zoo.items(): + # TODO(ver217): lazy init does not support weight norm, skip these models + if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base'): + continue + # FIXME(ver217): uncomment this line + # check_lazy_init(entry, verbose=True) + + +if __name__ == '__main__': + test_torchvision_models_lazy_init('torchvision') diff --git a/tests/test_utils/test_lazy_init/utils.py b/tests/test_utils/test_lazy_init/utils.py new file mode 100644 index 000000000..47ba534bc --- /dev/null +++ b/tests/test_utils/test_lazy_init/utils.py @@ -0,0 +1,69 @@ +import random +from typing import Any, Callable, Optional, Tuple + +import numpy as np +import torch + +from colossalai.utils.model.experimental import LazyInitContext, LazyTensor, _MyTensor +from tests.kit.model_zoo.registry import ModelAttribute + +# model_fn, data_gen_fn, output_transform_fn, model_attr +TestingEntry = Tuple[Callable[[], torch.nn.Module], Callable[[], dict], Callable[[], dict], Optional[ModelAttribute]] + + +def set_seed(seed: int) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + +def assert_model_eqaual(m1: torch.nn.Module, m2: torch.nn.Module) -> None: + s1 = m1.state_dict() + s2 = m2.state_dict() + + assert len(s1) == len(s2), f'len {len(s1)} vs {len(s2)}' + + for (n1, t1), (n2, t2) in zip(s1.items(), s2.items()): + assert n1 == n2 + assert torch.equal(t1, t2), f'{n1} {t1} vs {t2}' + + +def assert_forward_equal(m1: torch.nn.Module, m2: torch.nn.Module, data_gen_fn: Callable[[], dict], + output_transform_fn: Callable[[Any], dict]) -> None: + data = data_gen_fn() + + m1.eval() + m2.eval() + # run forward + with torch.no_grad(): + outputs1 = m1(**data) + outputs2 = m2(**data) + + # compare output + transformed_out1 = output_transform_fn(outputs1) + transformed_out2 = output_transform_fn(outputs2) + + assert len(transformed_out1) == len(transformed_out2) + + for key, out1 in transformed_out1.items(): + out2 = transformed_out2[key] + assert torch.allclose(out1, out2, atol=1e-5), \ + f'{m1.__class__.__name__} has inconsistent outputs, {out1} vs {out2}' + + +def check_lazy_init(entry: TestingEntry, seed: int = 42, verbose: bool = False, check_forward: bool = False) -> None: + model_fn, data_gen_fn, output_transform_fn, model_attr = entry + _MyTensor._pre_op_fn = lambda *args: set_seed(seed) + LazyTensor._pre_op_fn = lambda *args: set_seed(seed) + ctx = LazyInitContext(tensor_cls=_MyTensor) + with ctx: + model = model_fn() + ctx = LazyInitContext() + with ctx: + deferred_model = model_fn() + deferred_model = ctx.materialize(deferred_model, verbose=verbose) + assert_model_eqaual(model, deferred_model) + if check_forward: + assert_forward_equal(model, deferred_model, data_gen_fn, output_transform_fn) + if verbose: + print(f'{model.__class__.__name__} pass')