[lazyinit] add correctness verification (#3147)

* [lazyinit] fix shared module

* [tests] add lazy init test utils

* [tests] add torchvision for lazy init

* [lazyinit] fix pre op fn

* [lazyinit] handle legacy constructor

* [tests] refactor lazy init test models

* [tests] refactor lazy init test utils

* [lazyinit] fix ops that don't support meta

* [tests] lazy init test timm models

* [lazyinit] fix set data

* [lazyinit] handle apex layers

* [tests] lazy init test transformers models

* [tests] lazy init test torchaudio models

* [lazyinit] fix import path

* [tests] lazy init test torchrec models

* [tests] update torch version in CI

* [tests] revert torch version in CI

* [tests] skip lazy init test
ver217 2023-03-17 13:49:04 +08:00 committed by GitHub
parent 3c01280a56
commit 6ae8ed0407
4 changed files with 226 additions and 32 deletions
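
The correctness verification added here builds each model twice with the RNG reseeded before every tensor op: once eagerly through `_MyTensor` and once lazily through `LazyTensor`, then compares state dicts after materialization. Below is a minimal sketch of that flow, modeled on the `check_lazy_init` helper introduced further down; the tiny MLP, the seed value, and the helper name `seed_before_op` are illustrative and not part of this diff.

import torch
import torch.nn as nn

from colossalai.utils.model.experimental import LazyInitContext, LazyTensor, _MyTensor


def build_model() -> nn.Module:
    # any model constructor works; a tiny MLP keeps the sketch short
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


def seed_before_op(*args):
    # reseed the RNG before every materializing op so the eager and lazy
    # paths draw identical random weights
    torch.manual_seed(42)


_MyTensor._pre_op_fn = seed_before_op
LazyTensor._pre_op_fn = seed_before_op

# eager reference: _MyTensor executes factory calls immediately
with LazyInitContext(tensor_cls=_MyTensor):
    eager_model = build_model()

# lazy path: LazyTensor records factory calls and replays them on materialize()
ctx = LazyInitContext()
with ctx:
    lazy_model = build_model()
lazy_model = ctx.materialize(lazy_model, verbose=True)

# after materialization the two state dicts should match exactly
for (n1, t1), (n2, t2) in zip(eager_model.state_dict().items(), lazy_model.state_dict().items()):
    assert n1 == n2 and torch.equal(t1, t2)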

colossalai/utils/model/experimental.py

@@ -1,17 +1,16 @@
-from typing import Callable, Optional, Union
+from typing import Callable, List, Optional, Union
 
 import torch
 import torch.nn as nn
 from torch import Tensor
 from torch.utils._pytree import tree_map
 
-from colossalai.fx.profiler import MetaTensor
+from colossalai.fx.profiler.tensor import MetaTensor
 
 # reference: https://pytorch.org/cppdocs/notes/tensor_creation.html
-_TorchFactoryMethod = [
+_NORMAL_FACTORY = [
     "arange",
     "empty",
-    "eye",
     "full",
     "linspace",
     "logspace",
@@ -24,17 +23,39 @@ _TorchFactoryMethod = [
     "tensor",
 ]
 
+# factory function that does not support meta tensor backend
+_NO_META_FACTORY = [
+    "eye",
+]
+
 _EARLY_MATERIALIZED_OPS = ['__getitem__', 'split']
 
+_LEGACY_TENSOR_CONSTRUCTOR = {
+    'FloatTensor': torch.float,
+    'DoubleTensor': torch.double,
+    'HalfTensor': torch.half,
+    'BFloat16Tensor': torch.bfloat16,
+    'ByteTensor': torch.uint8,
+    'CharTensor': torch.int8,
+    'ShortTensor': torch.short,
+    'IntTensor': torch.int,
+    'LongTensor': torch.long,
+    'BoolTensor': torch.bool,
+}
+
 
 class _MyTensor(Tensor):
     """This class is only for correctness verification.
     """
     _pre_op_fn: Callable[['LazyTensor'], None] = lambda *args: None
 
-    def __new__(cls, func, *args, dtype=None, device=None, **kwargs) -> '_MyTensor':
+    def __new__(cls, func, *args, concrete_data=None, **kwargs) -> '_MyTensor':
         cls._pre_op_fn()
-        data = func(*args, dtype=dtype, device=device, **kwargs)
+        if concrete_data is not None:
+            # uniform api as LazyTensor
+            data = concrete_data
+        else:
+            data = func(*args, **kwargs)
         return Tensor._make_subclass(cls, data, require_grad=data.requires_grad)
 
     @classmethod
@@ -66,11 +87,13 @@ class LazyTensor(torch.Tensor):
         >>> x.add_(1) # modifying origin tensor after cloning leads to wrong materialization
         >>> z = x.tolist()
         >>> x.zeros_() # modifying origin tensor after cloning tolist is not allowed
-        >>> x.data = torch.rand(2, 3) # directly set data of a lazy tensor is not allowed
+        >>> nn.utils.weight_norm(self.conv, name="weight", dim=2) # applying weight norm on a lazy tensor is not allowed
 
     2. Cases that ``LazyTensor`` becomes eager (early materialization).
 
         >>> b = a[:, 2:] # get a slice of a lazy tensor triggers early materialization
         >>> chunks = a.split(3) # this also triggers early materialization
+        >>> x.data = torch.rand(2, 3) # directly setting data of a lazy tensor triggers early materialization
 
     """
@@ -79,12 +102,16 @@ class LazyTensor(torch.Tensor):
     _pre_op_fn: Callable[['LazyTensor'], None] = lambda *args: None
 
     @staticmethod
-    def __new__(cls, func, *args, meta_data=None, **kwargs):
-        if meta_data is None:
-            device = kwargs.get('device', 'cpu')
-            elem = func(*args, **{**kwargs, 'device': 'meta'})
-            meta_data = MetaTensor(elem, fake_device=device)
-        elem = meta_data._tensor
+    def __new__(cls, func, *args, meta_data=None, concrete_data=None, **kwargs):
+        if concrete_data is not None:
+            # some ops don't support meta backend and should have concrete data
+            elem = concrete_data
+        else:
+            if meta_data is None:
+                device = kwargs.get('device', 'cpu')
+                elem = func(*args, **{**kwargs, 'device': 'meta'})
+                meta_data = MetaTensor(elem, fake_device=device)
+            elem = meta_data._tensor
         r = torch.Tensor._make_wrapper_subclass(cls,
                                                 elem.size(),
                                                 strides=elem.stride(),
@@ -96,10 +123,10 @@ class LazyTensor(torch.Tensor):
         r._meta_data = meta_data
         return r
 
-    def __init__(self, func, *args, meta_data=None, **kwargs):
+    def __init__(self, func, *args, meta_data=None, concrete_data=None, **kwargs):
         self._factory_method = (func, args, kwargs)    # (func, args, kwargs)
         self._op_buffer = []    # (func, args, kwargs, replace)
-        self._materialized_data: Optional[torch.Tensor] = None    # materialized data
+        self._materialized_data: Optional[torch.Tensor] = concrete_data    # materialized data
 
     def materialize(self) -> torch.Tensor:
         """Materialize the ``LazyTensor`` to ``torch.Tensor``.
@@ -212,7 +239,7 @@ class LazyTensor(torch.Tensor):
             if isinstance(x, LazyTensor):
                 if x._materialized_data is not None:
                     # for early materialized tensor, use its materialized data directly
-                    return x._materialized_data
+                    return x._materialized_data.data
                 t = x if is_inplace else x.clone()
                 t._op_buffer.append((func, args, kwargs))
                 meta = x._meta_data.data
@@ -232,13 +259,10 @@ class LazyTensor(torch.Tensor):
                 return lazy_y
             elif type(y) is Tensor:
                 # for early materialized tensor
-                with torch._C.DisableTorchFunction():
-                    meta = MetaTensor(y.new_empty(y.shape, dtype=y.dtype, device='meta'), fake_device=y.device)
-                lazy_y = LazyTensor(lambda: None, meta_data=meta)
-                lazy_y._materialized_data = y
-                return lazy_y
+                return LazyTensor(lambda: None, concrete_data=y)
             return y
 
+        cls._pre_op_fn()
         o = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
         if isinstance(o, (tuple, list)):
             return type(o)(wrap(y, i=i) for i, y in enumerate(o))
@@ -266,7 +290,10 @@ class LazyTensor(torch.Tensor):
 
     @data.setter
     def data(self, other: 'LazyTensor'):
-        raise NotImplementedError
+        if other is self:
+            return
+        # TODO(ver217): to avoid infinity recursion, do early materialization
+        self._materialized_data = other._materialize_data()
 
     def tolist(self) -> list:
         t = self.materialize()
@@ -330,18 +357,61 @@ class LazyInitContext:
 
            return wrapper, target
 
+        def wrap_legacy_constructor(target, dtype):
+            # legacy constructor (e.g. torch.LongTensor())
+            def wrapper(*args, **kwargs):
+                if len(args) == 1 and isinstance(args[0], torch.Tensor):
+                    # (Tensor other)
+                    return args[0]
+                elif len(args) == 1:
+                    # (object data, *, torch.device device)
+                    kwargs = {**kwargs, 'dtype': dtype}
+                    replaced, orig = self.overrides['tensor']
+                    return replaced(*args, **kwargs)
+                elif _is_int_tuple(args):
+                    # (tuple of ints size, *, torch.device device)
+                    kwargs = {**kwargs, 'dtype': dtype}
+                    replaced, orig = self.overrides['empty']
+                    return replaced(*args, **kwargs)
+                else:
+                    raise TypeError(
+                        f'new() received an invalid combination of arguments - got {tuple(type(x) for x in args)}, but expected one of:\n * (Tensor other)\n * (tuple of ints size, *, torch.device device)\n * (object data, *, torch.device device)'
+                    )
+
+            return wrapper, target
+
+        def wrap_no_meta_factory(target):
+            # factory functions which don't support meta tensor backend
+            def wrapper(*args, **kwargs):
+                tensor = target(*args, **kwargs)
+                return self.tensor_cls(lambda: None, concrete_data=tensor)
+
+            return wrapper, target
+
         self.overrides = {
             target: wrap_factory_method(getattr(torch, target))
-            for target in _TorchFactoryMethod
+            for target in _NORMAL_FACTORY
             if callable(getattr(torch, target, None))
         }
 
         self.overrides.update({
             target + '_like': wrap_factory_like_method(getattr(torch, target), getattr(torch, target + '_like'))
-            for target in _TorchFactoryMethod
+            for target in _NORMAL_FACTORY
             if callable(getattr(torch, target + '_like', None))
         })
 
+        self.overrides.update({
+            target: wrap_legacy_constructor(getattr(torch, target), dtype)
+            for target, dtype in _LEGACY_TENSOR_CONSTRUCTOR.items()
+            if callable(getattr(torch, target, None))
+        })
+
+        self.overrides.update({
+            target: wrap_no_meta_factory(getattr(torch, target))
+            for target in _NO_META_FACTORY
+            if callable(getattr(torch, target, None))
+        })
+
         for name, (wrapper, orig) in self.overrides.items():
             setattr(torch, name, wrapper)
@@ -363,34 +433,65 @@ class LazyInitContext:
             param_lazy_cnt = 0
             buf_cnt = 0
             buf_lazy_cnt = 0
+            non_lazy_numel = 0
+
+        # do post cleaning to handle shared parameter
+        visited_lazy_tensors: List[LazyTensor] = []
+        # handle shared module
+        visited_modules = set()
 
         @torch.no_grad()
         def init_recursively(module: nn.Module):
-            nonlocal param_cnt, param_lazy_cnt, buf_cnt, buf_lazy_cnt
+            nonlocal param_cnt, param_lazy_cnt, buf_cnt, buf_lazy_cnt, non_lazy_numel
             # recursively initialize the module
             for mod in module.children():
-                init_recursively(mod)
+                if id(mod) not in visited_modules:
+                    visited_modules.add(id(mod))
+                    init_recursively(mod)
 
             # initialize tensors directly attached to the current module
             for name, param in module.named_parameters(recurse=False):
                 if verbose:
                     param_cnt += 1
-                    if param._materialized_data is None:
+                    if getattr(param, '_materialized_data', False) is None:
+                        # if no _materialized_data attr, the tensor is not lazy
                         param_lazy_cnt += 1
-                setattr(module, name, param.materialize())
-                param.clean()
+                    else:
+                        non_lazy_numel += param.numel()
+                if hasattr(param, 'materialize'):
+                    # TODO(ver217): apex layers cannot be captured
+                    visited_lazy_tensors.append(param)
+                    setattr(module, name, param.materialize())
 
             for name, buf in module.named_buffers(recurse=False):
                 if verbose:
                     buf_cnt += 1
-                    if buf._materialized_data is None:
+                    if getattr(buf, "_materialized_data", False) is None:
+                        # if no _materialized_data attr, the tensor is not lazy
                         buf_lazy_cnt += 1
-                setattr(module, name, buf.materialize())
-                buf.clean()
+                    else:
+                        non_lazy_numel += buf.numel()
+                if hasattr(buf, 'materialize'):
+                    # TODO(ver217): apex layers cannot be captured
+                    visited_lazy_tensors.append(buf)
+                    setattr(module, name, buf.materialize())
 
         init_recursively(module)
 
+        for t in visited_lazy_tensors:
+            t.clean()
+
         if verbose:
             print(f'Param lazy rate: {param_lazy_cnt}/{param_cnt}')
             print(f'Buffer lazy rate: {buf_lazy_cnt}/{buf_cnt}')
+            print(f'Non-lazy numel: {non_lazy_numel} ({non_lazy_numel/1024**2:.3f} M)')
 
         return module
+
+
+def _is_int_tuple(args) -> bool:
+    if not isinstance(args, tuple):
+        return False
+    for x in args:
+        if not isinstance(x, int):
+            return False
+    return True
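
Taken together, the new `_LEGACY_TENSOR_CONSTRUCTOR` and `_NO_META_FACTORY` overrides route legacy constructors such as `torch.LongTensor()` to the wrapped `empty`/`tensor` factories with the matching dtype, while factories without a meta backend (currently only `eye`) are wrapped around concrete data. The following is a rough sketch of the expected behaviour inside the context, not executed output; the exact returned type depends on the configured `tensor_cls`.

import torch

from colossalai.utils.model.experimental import LazyInitContext, LazyTensor

with LazyInitContext():
    a = torch.LongTensor(2, 3)          # int-tuple args: dispatched to the wrapped 'empty' with dtype=torch.long
    b = torch.LongTensor([1, 2, 3])     # data args: dispatched to the wrapped 'tensor' with dtype=torch.long
    c = torch.eye(4)                    # 'eye' has no meta backend, so it is wrapped with concrete data

    assert isinstance(a, LazyTensor) and a.dtype == torch.long
    assert isinstance(b, LazyTensor) and b.dtype == torch.long
    assert isinstance(c, LazyTensor)    # already holds concrete data, materializes trivially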


@@ -0,0 +1 @@
+from .torchrec import *


@@ -0,0 +1,23 @@
+import pytest
+
+from tests.kit.model_zoo import model_zoo
+
+# FIXME(ver217): uncomment this line
+# from utils import check_lazy_init
+
+
+# FIXME(ver217): temporarily skip this test since torch 1.11 does not fully support meta tensor
+@pytest.mark.skip
+@pytest.mark.parametrize('subset', ['torchvision', 'diffusers', 'timm', 'transformers', 'torchaudio', 'deepfm', 'dlrm'])
+def test_torchvision_models_lazy_init(subset):
+    sub_model_zoo = model_zoo.get_sub_registry(subset)
+    for name, entry in sub_model_zoo.items():
+        # TODO(ver217): lazy init does not support weight norm, skip these models
+        if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base'):
+            continue
+        # FIXME(ver217): uncomment this line
+        # check_lazy_init(entry, verbose=True)
+
+
+if __name__ == '__main__':
+    test_torchvision_models_lazy_init('torchvision')


@@ -0,0 +1,69 @@
+import random
+from typing import Any, Callable, Optional, Tuple
+
+import numpy as np
+import torch
+
+from colossalai.utils.model.experimental import LazyInitContext, LazyTensor, _MyTensor
+from tests.kit.model_zoo.registry import ModelAttribute
+
+# model_fn, data_gen_fn, output_transform_fn, model_attr
+TestingEntry = Tuple[Callable[[], torch.nn.Module], Callable[[], dict], Callable[[], dict], Optional[ModelAttribute]]
+
+
+def set_seed(seed: int) -> None:
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+
+
+def assert_model_eqaual(m1: torch.nn.Module, m2: torch.nn.Module) -> None:
+    s1 = m1.state_dict()
+    s2 = m2.state_dict()
+
+    assert len(s1) == len(s2), f'len {len(s1)} vs {len(s2)}'
+
+    for (n1, t1), (n2, t2) in zip(s1.items(), s2.items()):
+        assert n1 == n2
+        assert torch.equal(t1, t2), f'{n1} {t1} vs {t2}'
+
+
+def assert_forward_equal(m1: torch.nn.Module, m2: torch.nn.Module, data_gen_fn: Callable[[], dict],
+                         output_transform_fn: Callable[[Any], dict]) -> None:
+    data = data_gen_fn()
+
+    m1.eval()
+    m2.eval()
+
+    # run forward
+    with torch.no_grad():
+        outputs1 = m1(**data)
+        outputs2 = m2(**data)
+
+    # compare output
+    transformed_out1 = output_transform_fn(outputs1)
+    transformed_out2 = output_transform_fn(outputs2)
+
+    assert len(transformed_out1) == len(transformed_out2)
+
+    for key, out1 in transformed_out1.items():
+        out2 = transformed_out2[key]
+        assert torch.allclose(out1, out2, atol=1e-5), \
+            f'{m1.__class__.__name__} has inconsistent outputs, {out1} vs {out2}'
+
+
+def check_lazy_init(entry: TestingEntry, seed: int = 42, verbose: bool = False, check_forward: bool = False) -> None:
+    model_fn, data_gen_fn, output_transform_fn, model_attr = entry
+    _MyTensor._pre_op_fn = lambda *args: set_seed(seed)
+    LazyTensor._pre_op_fn = lambda *args: set_seed(seed)
+    ctx = LazyInitContext(tensor_cls=_MyTensor)
+    with ctx:
+        model = model_fn()
+    ctx = LazyInitContext()
+    with ctx:
+        deferred_model = model_fn()
+    deferred_model = ctx.materialize(deferred_model, verbose=verbose)
+    assert_model_eqaual(model, deferred_model)
+    if check_forward:
+        assert_forward_equal(model, deferred_model, data_gen_fn, output_transform_fn)
+    if verbose:
+        print(f'{model.__class__.__name__} pass')
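
For reference, a `TestingEntry` can also be built by hand and passed to `check_lazy_init` outside the model zoo. The sketch below is hypothetical: the single-layer model, data generator, and the `from utils import ...` path are placeholders, assuming this utils module is importable from the test directory.

import torch
import torch.nn as nn

from utils import check_lazy_init    # assumption: the helper module above is importable as `utils`


def model_fn() -> nn.Module:
    # a minimal stand-in model; real tests iterate over the model zoo registry
    return nn.Linear(8, 4)


def data_gen_fn() -> dict:
    # keyword must match the model's forward signature (nn.Linear takes `input`)
    return {'input': torch.rand(2, 8)}


def output_transform_fn(output) -> dict:
    return {'output': output}


# (model_fn, data_gen_fn, output_transform_fn, model_attr)
entry = (model_fn, data_gen_fn, output_transform_fn, None)
check_lazy_init(entry, verbose=True, check_forward=True)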