mirror of https://github.com/hpcaitech/ColossalAI
[lazyinit] add correctness verification (#3147)
* [lazyinit] fix shared module * [tests] add lazy init test utils * [tests] add torchvision for lazy init * [lazyinit] fix pre op fn * [lazyinit] handle legacy constructor * [tests] refactor lazy init test models * [tests] refactor lazy init test utils * [lazyinit] fix ops don't support meta * [tests] lazy init test timm models * [lazyinit] fix set data * [lazyinit] handle apex layers * [tests] lazy init test transformers models * [tests] lazy init test torchaudio models * [lazyinit] fix import path * [tests] lazy init test torchrec models * [tests] update torch version in CI * [tests] revert torch version in CI * [tests] skip lazy init testpull/3162/head
parent
3c01280a56
commit
6ae8ed0407
|
@ -1,17 +1,16 @@
|
|||
from typing import Callable, Optional, Union
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
from colossalai.fx.profiler import MetaTensor
|
||||
from colossalai.fx.profiler.tensor import MetaTensor
|
||||
|
||||
# reference: https://pytorch.org/cppdocs/notes/tensor_creation.html
|
||||
_TorchFactoryMethod = [
|
||||
_NORMAL_FACTORY = [
|
||||
"arange",
|
||||
"empty",
|
||||
"eye",
|
||||
"full",
|
||||
"linspace",
|
||||
"logspace",
|
||||
|
@ -24,17 +23,39 @@ _TorchFactoryMethod = [
|
|||
"tensor",
|
||||
]
|
||||
|
||||
# factory function that does not support meta tensor backend
|
||||
_NO_META_FACTORY = [
|
||||
"eye",
|
||||
]
|
||||
|
||||
_EARLY_MATERIALIZED_OPS = ['__getitem__', 'split']
|
||||
|
||||
_LEGACY_TENSOR_CONSTRUCTOR = {
|
||||
'FloatTensor': torch.float,
|
||||
'DoubleTensor': torch.double,
|
||||
'HalfTensor': torch.half,
|
||||
'BFloat16Tensor': torch.bfloat16,
|
||||
'ByteTensor': torch.uint8,
|
||||
'CharTensor': torch.int8,
|
||||
'ShortTensor': torch.short,
|
||||
'IntTensor': torch.int,
|
||||
'LongTensor': torch.long,
|
||||
'BoolTensor': torch.bool,
|
||||
}
|
||||
|
||||
|
||||
class _MyTensor(Tensor):
|
||||
"""This class is only for correctness verification.
|
||||
"""
|
||||
_pre_op_fn: Callable[['LazyTensor'], None] = lambda *args: None
|
||||
|
||||
def __new__(cls, func, *args, dtype=None, device=None, **kwargs) -> '_MyTensor':
|
||||
def __new__(cls, func, *args, concrete_data=None, **kwargs) -> '_MyTensor':
|
||||
cls._pre_op_fn()
|
||||
data = func(*args, dtype=dtype, device=device, **kwargs)
|
||||
if concrete_data is not None:
|
||||
# uniform api as LazyTensor
|
||||
data = concrete_data
|
||||
else:
|
||||
data = func(*args, **kwargs)
|
||||
return Tensor._make_subclass(cls, data, require_grad=data.requires_grad)
|
||||
|
||||
@classmethod
|
||||
|
@ -66,11 +87,13 @@ class LazyTensor(torch.Tensor):
|
|||
>>> x.add_(1) # modifying origin tensor after cloning leads to wrong materialization
|
||||
>>> z = x.tolist()
|
||||
>>> x.zeros_() # modifying origin tensor after cloning tolist is not allowed
|
||||
>>> x.data = torch.rand(2, 3) # directly set data of a lazy tensor is not allowed
|
||||
>>> nn.utils.weight_norm(self.conv, name="weight", dim=2) # applying weight norm on a lazy tensor is not allowed
|
||||
|
||||
|
||||
2. Cases that ``LazyTensor`` becomes eager (early materialization).
|
||||
>>> b = a[:, 2:] # get a slice of a lazy tensor triggers early materialization
|
||||
>>> chunks = a.split(3) # this also triggers early materialization
|
||||
>>> x.data = torch.rand(2, 3) # directly setting data of a lazy tensor triggers early materialization
|
||||
|
||||
"""
|
||||
|
||||
|
@ -79,12 +102,16 @@ class LazyTensor(torch.Tensor):
|
|||
_pre_op_fn: Callable[['LazyTensor'], None] = lambda *args: None
|
||||
|
||||
@staticmethod
|
||||
def __new__(cls, func, *args, meta_data=None, **kwargs):
|
||||
if meta_data is None:
|
||||
device = kwargs.get('device', 'cpu')
|
||||
elem = func(*args, **{**kwargs, 'device': 'meta'})
|
||||
meta_data = MetaTensor(elem, fake_device=device)
|
||||
elem = meta_data._tensor
|
||||
def __new__(cls, func, *args, meta_data=None, concrete_data=None, **kwargs):
|
||||
if concrete_data is not None:
|
||||
# some ops don't support meta backend and should have concrete data
|
||||
elem = concrete_data
|
||||
else:
|
||||
if meta_data is None:
|
||||
device = kwargs.get('device', 'cpu')
|
||||
elem = func(*args, **{**kwargs, 'device': 'meta'})
|
||||
meta_data = MetaTensor(elem, fake_device=device)
|
||||
elem = meta_data._tensor
|
||||
r = torch.Tensor._make_wrapper_subclass(cls,
|
||||
elem.size(),
|
||||
strides=elem.stride(),
|
||||
|
@ -96,10 +123,10 @@ class LazyTensor(torch.Tensor):
|
|||
r._meta_data = meta_data
|
||||
return r
|
||||
|
||||
def __init__(self, func, *args, meta_data=None, **kwargs):
|
||||
def __init__(self, func, *args, meta_data=None, concrete_data=None, **kwargs):
|
||||
self._factory_method = (func, args, kwargs) # (func, args, kwargs)
|
||||
self._op_buffer = [] # (func, args, kwargs, replace)
|
||||
self._materialized_data: Optional[torch.Tensor] = None # materialized data
|
||||
self._materialized_data: Optional[torch.Tensor] = concrete_data # materialized data
|
||||
|
||||
def materialize(self) -> torch.Tensor:
|
||||
"""Materialize the ``LazyTensor`` to ``torch.Tensor``.
|
||||
|
@ -212,7 +239,7 @@ class LazyTensor(torch.Tensor):
|
|||
if isinstance(x, LazyTensor):
|
||||
if x._materialized_data is not None:
|
||||
# for early materialized tensor, use its materialized data directly
|
||||
return x._materialized_data
|
||||
return x._materialized_data.data
|
||||
t = x if is_inplace else x.clone()
|
||||
t._op_buffer.append((func, args, kwargs))
|
||||
meta = x._meta_data.data
|
||||
|
@ -232,13 +259,10 @@ class LazyTensor(torch.Tensor):
|
|||
return lazy_y
|
||||
elif type(y) is Tensor:
|
||||
# for early materialized tensor
|
||||
with torch._C.DisableTorchFunction():
|
||||
meta = MetaTensor(y.new_empty(y.shape, dtype=y.dtype, device='meta'), fake_device=y.device)
|
||||
lazy_y = LazyTensor(lambda: None, meta_data=meta)
|
||||
lazy_y._materialized_data = y
|
||||
return lazy_y
|
||||
return LazyTensor(lambda: None, concrete_data=y)
|
||||
return y
|
||||
|
||||
cls._pre_op_fn()
|
||||
o = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
|
||||
if isinstance(o, (tuple, list)):
|
||||
return type(o)(wrap(y, i=i) for i, y in enumerate(o))
|
||||
|
@ -266,7 +290,10 @@ class LazyTensor(torch.Tensor):
|
|||
|
||||
@data.setter
|
||||
def data(self, other: 'LazyTensor'):
|
||||
raise NotImplementedError
|
||||
if other is self:
|
||||
return
|
||||
# TODO(ver217): to avoid infinity recursion, do early materialization
|
||||
self._materialized_data = other._materialize_data()
|
||||
|
||||
def tolist(self) -> list:
|
||||
t = self.materialize()
|
||||
|
@ -330,18 +357,61 @@ class LazyInitContext:
|
|||
|
||||
return wrapper, target
|
||||
|
||||
def wrap_legacy_constructor(target, dtype):
|
||||
# legacy constructor (e.g. torch.LongTensor())
|
||||
def wrapper(*args, **kwargs):
|
||||
if len(args) == 1 and isinstance(args[0], torch.Tensor):
|
||||
# (Tensor other)
|
||||
return args[0]
|
||||
elif len(args) == 1:
|
||||
# (object data, *, torch.device device)
|
||||
kwargs = {**kwargs, 'dtype': dtype}
|
||||
replaced, orig = self.overrides['tensor']
|
||||
return replaced(*args, **kwargs)
|
||||
elif _is_int_tuple(args):
|
||||
# (tuple of ints size, *, torch.device device)
|
||||
kwargs = {**kwargs, 'dtype': dtype}
|
||||
replaced, orig = self.overrides['empty']
|
||||
return replaced(*args, **kwargs)
|
||||
else:
|
||||
raise TypeError(
|
||||
f'new() received an invalid combination of arguments - got {tuple(type(x) for x in args)}, but expected one of:\n * (Tensor other)\n * (tuple of ints size, *, torch.device device)\n * (object data, *, torch.device device)'
|
||||
)
|
||||
|
||||
return wrapper, target
|
||||
|
||||
def wrap_no_meta_factory(target):
|
||||
# factory functions which don't support meta tensor backend
|
||||
def wrapper(*args, **kwargs):
|
||||
tensor = target(*args, **kwargs)
|
||||
return self.tensor_cls(lambda: None, concrete_data=tensor)
|
||||
|
||||
return wrapper, target
|
||||
|
||||
self.overrides = {
|
||||
target: wrap_factory_method(getattr(torch, target))
|
||||
for target in _TorchFactoryMethod
|
||||
for target in _NORMAL_FACTORY
|
||||
if callable(getattr(torch, target, None))
|
||||
}
|
||||
|
||||
self.overrides.update({
|
||||
target + '_like': wrap_factory_like_method(getattr(torch, target), getattr(torch, target + '_like'))
|
||||
for target in _TorchFactoryMethod
|
||||
for target in _NORMAL_FACTORY
|
||||
if callable(getattr(torch, target + '_like', None))
|
||||
})
|
||||
|
||||
self.overrides.update({
|
||||
target: wrap_legacy_constructor(getattr(torch, target), dtype)
|
||||
for target, dtype in _LEGACY_TENSOR_CONSTRUCTOR.items()
|
||||
if callable(getattr(torch, target, None))
|
||||
})
|
||||
|
||||
self.overrides.update({
|
||||
target: wrap_no_meta_factory(getattr(torch, target))
|
||||
for target in _NO_META_FACTORY
|
||||
if callable(getattr(torch, target, None))
|
||||
})
|
||||
|
||||
for name, (wrapper, orig) in self.overrides.items():
|
||||
setattr(torch, name, wrapper)
|
||||
|
||||
|
@ -363,34 +433,65 @@ class LazyInitContext:
|
|||
param_lazy_cnt = 0
|
||||
buf_cnt = 0
|
||||
buf_lazy_cnt = 0
|
||||
non_lazy_numel = 0
|
||||
|
||||
# do post cleaning to handle shared parameter
|
||||
visited_lazy_tensors: List[LazyTensor] = []
|
||||
# handle shared module
|
||||
visited_modules = set()
|
||||
|
||||
@torch.no_grad()
|
||||
def init_recursively(module: nn.Module):
|
||||
nonlocal param_cnt, param_lazy_cnt, buf_cnt, buf_lazy_cnt
|
||||
nonlocal param_cnt, param_lazy_cnt, buf_cnt, buf_lazy_cnt, non_lazy_numel
|
||||
# recursively initialize the module
|
||||
for mod in module.children():
|
||||
init_recursively(mod)
|
||||
if id(mod) not in visited_modules:
|
||||
visited_modules.add(id(mod))
|
||||
init_recursively(mod)
|
||||
|
||||
# initialize tensors directly attached to the current module
|
||||
for name, param in module.named_parameters(recurse=False):
|
||||
if verbose:
|
||||
param_cnt += 1
|
||||
if param._materialized_data is None:
|
||||
if getattr(param, '_materialized_data', False) is None:
|
||||
# if no _materialized_data attr, the tensor is not lazy
|
||||
param_lazy_cnt += 1
|
||||
setattr(module, name, param.materialize())
|
||||
param.clean()
|
||||
else:
|
||||
non_lazy_numel += param.numel()
|
||||
if hasattr(param, 'materialize'):
|
||||
# TODO(ver217): apex layers cannot be captured
|
||||
visited_lazy_tensors.append(param)
|
||||
setattr(module, name, param.materialize())
|
||||
|
||||
for name, buf in module.named_buffers(recurse=False):
|
||||
if verbose:
|
||||
buf_cnt += 1
|
||||
if buf._materialized_data is None:
|
||||
if getattr(buf, "_materialized_data", False) is None:
|
||||
# if no _materialized_data attr, the tensor is not lazy
|
||||
buf_lazy_cnt += 1
|
||||
setattr(module, name, buf.materialize())
|
||||
buf.clean()
|
||||
else:
|
||||
non_lazy_numel += buf.numel()
|
||||
if hasattr(buf, 'materialize'):
|
||||
# TODO(ver217): apex layers cannot be captured
|
||||
visited_lazy_tensors.append(buf)
|
||||
setattr(module, name, buf.materialize())
|
||||
|
||||
init_recursively(module)
|
||||
|
||||
for t in visited_lazy_tensors:
|
||||
t.clean()
|
||||
|
||||
if verbose:
|
||||
print(f'Param lazy rate: {param_lazy_cnt}/{param_cnt}')
|
||||
print(f'Buffer lazy rate: {buf_lazy_cnt}/{buf_cnt}')
|
||||
print(f'Non-lazy numel: {non_lazy_numel} ({non_lazy_numel/1024**2:.3f} M)')
|
||||
return module
|
||||
|
||||
|
||||
def _is_int_tuple(args) -> bool:
|
||||
if not isinstance(args, tuple):
|
||||
return False
|
||||
for x in args:
|
||||
if not isinstance(x, int):
|
||||
return False
|
||||
return True
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
from .torchrec import *
|
|
@ -0,0 +1,23 @@
|
|||
import pytest
|
||||
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
|
||||
# FIXME(ver217): uncomment this line
|
||||
# from utils import check_lazy_init
|
||||
|
||||
|
||||
# FIXME(ver217): temporarily skip this test since torch 1.11 does not fully support meta tensor
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.parametrize('subset', ['torchvision', 'diffusers', 'timm', 'transformers', 'torchaudio', 'deepfm', 'dlrm'])
|
||||
def test_torchvision_models_lazy_init(subset):
|
||||
sub_model_zoo = model_zoo.get_sub_registry(subset)
|
||||
for name, entry in sub_model_zoo.items():
|
||||
# TODO(ver217): lazy init does not support weight norm, skip these models
|
||||
if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base'):
|
||||
continue
|
||||
# FIXME(ver217): uncomment this line
|
||||
# check_lazy_init(entry, verbose=True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_torchvision_models_lazy_init('torchvision')
|
|
@ -0,0 +1,69 @@
|
|||
import random
|
||||
from typing import Any, Callable, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from colossalai.utils.model.experimental import LazyInitContext, LazyTensor, _MyTensor
|
||||
from tests.kit.model_zoo.registry import ModelAttribute
|
||||
|
||||
# model_fn, data_gen_fn, output_transform_fn, model_attr
|
||||
TestingEntry = Tuple[Callable[[], torch.nn.Module], Callable[[], dict], Callable[[], dict], Optional[ModelAttribute]]
|
||||
|
||||
|
||||
def set_seed(seed: int) -> None:
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
|
||||
|
||||
def assert_model_eqaual(m1: torch.nn.Module, m2: torch.nn.Module) -> None:
|
||||
s1 = m1.state_dict()
|
||||
s2 = m2.state_dict()
|
||||
|
||||
assert len(s1) == len(s2), f'len {len(s1)} vs {len(s2)}'
|
||||
|
||||
for (n1, t1), (n2, t2) in zip(s1.items(), s2.items()):
|
||||
assert n1 == n2
|
||||
assert torch.equal(t1, t2), f'{n1} {t1} vs {t2}'
|
||||
|
||||
|
||||
def assert_forward_equal(m1: torch.nn.Module, m2: torch.nn.Module, data_gen_fn: Callable[[], dict],
|
||||
output_transform_fn: Callable[[Any], dict]) -> None:
|
||||
data = data_gen_fn()
|
||||
|
||||
m1.eval()
|
||||
m2.eval()
|
||||
# run forward
|
||||
with torch.no_grad():
|
||||
outputs1 = m1(**data)
|
||||
outputs2 = m2(**data)
|
||||
|
||||
# compare output
|
||||
transformed_out1 = output_transform_fn(outputs1)
|
||||
transformed_out2 = output_transform_fn(outputs2)
|
||||
|
||||
assert len(transformed_out1) == len(transformed_out2)
|
||||
|
||||
for key, out1 in transformed_out1.items():
|
||||
out2 = transformed_out2[key]
|
||||
assert torch.allclose(out1, out2, atol=1e-5), \
|
||||
f'{m1.__class__.__name__} has inconsistent outputs, {out1} vs {out2}'
|
||||
|
||||
|
||||
def check_lazy_init(entry: TestingEntry, seed: int = 42, verbose: bool = False, check_forward: bool = False) -> None:
|
||||
model_fn, data_gen_fn, output_transform_fn, model_attr = entry
|
||||
_MyTensor._pre_op_fn = lambda *args: set_seed(seed)
|
||||
LazyTensor._pre_op_fn = lambda *args: set_seed(seed)
|
||||
ctx = LazyInitContext(tensor_cls=_MyTensor)
|
||||
with ctx:
|
||||
model = model_fn()
|
||||
ctx = LazyInitContext()
|
||||
with ctx:
|
||||
deferred_model = model_fn()
|
||||
deferred_model = ctx.materialize(deferred_model, verbose=verbose)
|
||||
assert_model_eqaual(model, deferred_model)
|
||||
if check_forward:
|
||||
assert_forward_equal(model, deferred_model, data_gen_fn, output_transform_fn)
|
||||
if verbose:
|
||||
print(f'{model.__class__.__name__} pass')
|
Loading…
Reference in New Issue