[lazyinit] refactor lazy tensor and lazy init ctx (#3131)

* [lazyinit] refactor lazy tensor and lazy init ctx

* [lazyinit] polish docstr

* [lazyinit] polish docstr
pull/3135/head
ver217 2023-03-14 15:37:12 +08:00 committed by GitHub
parent 86ac782d7c
commit ed8f60b93b
1 changed file with 177 additions and 221 deletions
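In short, the refactor replaces the sharding-oriented LazyTensor API (spec/distribute/traceable) with a plain materialize-and-clean flow driven by LazyInitContext. A minimal usage sketch of that flow follows; the changed file's path is not visible in this view, so the import path below is an assumption.

    # Minimal sketch of the refactored flow; the import path is an assumption,
    # since the changed file's name is not shown on this page.
    import torch.nn as nn
    from colossalai.utils.model.experimental import LazyInitContext

    with LazyInitContext():
        # patched torch factory functions return LazyTensor, so no real memory is allocated here
        model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 10))

    # replace every LazyTensor parameter/buffer with a real tensor and free the op buffers
    model = LazyInitContext.materialize(model, verbose=True)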


@@ -1,17 +1,11 @@
-import contextlib
-import copy
-import gc
-import pprint
-from typing import Callable, List, Optional, Union
+from typing import Callable, Optional, Union
 
 import torch
 import torch.nn as nn
+from torch import Tensor
 from torch.utils._pytree import tree_map
 
-from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx.profiler import MetaTensor
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager
-from colossalai.tensor.sharding_spec import ShardingSpec
 
 # reference: https://pytorch.org/cppdocs/notes/tensor_creation.html
 _TorchFactoryMethod = [
@@ -30,9 +24,23 @@ _TorchFactoryMethod = [
     "tensor",
 ]
 
-orig_empty = torch.empty    # avoid override
-scm = ShapeConsistencyManager()
+_EARLY_MATERIALIZED_OPS = ['__getitem__', 'split']
+
+
+class _MyTensor(Tensor):
+    """This class is only for correctness verification.
+    """
+    _pre_op_fn: Callable[['LazyTensor'], None] = lambda *args: None
+
+    def __new__(cls, func, *args, dtype=None, device=None, **kwargs) -> '_MyTensor':
+        cls._pre_op_fn()
+        data = func(*args, dtype=dtype, device=device, **kwargs)
+        return Tensor._make_subclass(cls, data, require_grad=data.requires_grad)
+
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        cls._pre_op_fn()
+        return super().__torch_function__(func, types, args, kwargs)
 
 
 class LazyTensor(torch.Tensor):
@@ -50,140 +58,114 @@ class LazyTensor(torch.Tensor):
            tensor([[0., 1., 1.],
                    [1., 1., 1.]], device='cuda:0', dtype=torch.float16)
 
-    2. Generate ``MetaTensor`` from ``LazyTensor``
-        >>> x = LazyTensor(torch.zeros, 2, 3)
-        >>> x.reshape(3, 2)
-        >>> x = x.traceable()    # generate ``MetaTensor``
-        >>> print(x)
-        MetaTensor(..., size=(3, 2), device=cpu, dtype=torch.float32)
-
-    3. Use ``LazyTensor`` to generate sharded ``nn.Parameter``.
-        >>> x = LazyTensor(torch.zeros, 2, 3)
-        >>> x.spec = ...    # some ``ShardingSpec``
-        >>> x.distribute()    # distribute the tensor according to the ``ShardingSpec``
-
     Warnings:
         1. Cases that ``LazyTensor`` can't deal with.
            >>> x = LazyTensor(torch.ones, 2, 3)
            >>> x[0, 0] = -x[0, 0] # this will cause infinite recursion
-
-        2. ``LazyTensor.materialize()`` can't be called multiple times.
-           >>> x = LazyTensor(torch.ones, 2, 3)
-           >>> x.materialize()
-           >>> x.materialize() # this is disallowed
+           >>> y = x.clone()
+           >>> x.add_(1) # modifying origin tensor after cloning leads to wrong materialization
+           >>> z = x.tolist()
+           >>> x.zeros_() # modifying origin tensor after cloning tolist is not allowed
+           >>> x.data = torch.rand(2, 3) # directly set data of a lazy tensor is not allowed
+
+        2. Cases that ``LazyTensor`` becomes eager (early materialization).
+           >>> b = a[:, 2:]  # get a slice of a lazy tensor triggers early materialization
+           >>> chunks = a.split(3)  # this also triggers early materialization
 
     """
 
     _repr = True
     _meta_data: Optional[MetaTensor] = None    # shape, dtype, device
-    _cached_data: Optional[torch.Tensor] = None    # materialized data
+    _pre_op_fn: Callable[['LazyTensor'], None] = lambda *args: None
 
     @staticmethod
-    def __new__(cls, func, *args, dtype=None, device=None, **kwargs):
-        elem = func(*args, dtype=dtype, device='meta', **kwargs)
+    def __new__(cls, func, *args, meta_data=None, **kwargs):
+        if meta_data is None:
+            device = kwargs.get('device', 'cpu')
+            elem = func(*args, **{**kwargs, 'device': 'meta'})
+            meta_data = MetaTensor(elem, fake_device=device)
+        elem = meta_data._tensor
         r = torch.Tensor._make_wrapper_subclass(cls,
                                                 elem.size(),
                                                 strides=elem.stride(),
                                                 storage_offset=elem.storage_offset(),
                                                 dtype=elem.dtype,
                                                 layout=elem.layout,
-                                                device=device if device is not None else torch.device('cpu'),
+                                                device=elem.device,
                                                 requires_grad=elem.requires_grad)
-        r._meta_data = MetaTensor(elem, fake_device=device)
+        r._meta_data = meta_data
         return r
 
-    def __init__(self, func, *args, dtype=None, device=None, **kwargs):
-        self._factory_method = (func, args, {'dtype': dtype, 'device': device, **kwargs})    # (func, args, kwargs)
-        self._cached_buffer = list()    # (func, args, kwargs)
-        self._spec = None
-        self._data = self
-
-    def __repr__(self):
-        if self._repr:
-            # avoid recursive representation
-            self.__class__._repr = False
-            s = f'LazyTensor(..., size={tuple(self._meta_data.shape)}, device={self._meta_data.device}, dtype={self._meta_data.dtype})\n'\
-                f'factory method: {self._factory_method}\n'\
-                f'cached: {pprint.pformat(self._cached_buffer) if self._cached_data is None else self._cached_data}\n'\
-                f'spec: {self._spec}'
-            self.__class__._repr = True
-            return s
-        else:
-            return 'LazyTensor(...)'
+    def __init__(self, func, *args, meta_data=None, **kwargs):
+        self._factory_method = (func, args, kwargs)    # (func, args, kwargs)
+        self._op_buffer = []    # (func, args, kwargs, replace)
+        self._materialized_data: Optional[torch.Tensor] = None    # materialized data
 
     def materialize(self) -> torch.Tensor:
         """Materialize the ``LazyTensor`` to ``torch.Tensor``.
 
-        Warnings:
-            Calling ``self.materialize()`` will clear all cached sequence and factory method,
-            because we don't allow materialize the same ``LazyTensor`` twice.
-            This is mentioned in the paper: https://arxiv.org/pdf/2102.13267.pdf (Part 4.3).
-
         Returns:
             torch.Tensor: The materialized tensor.
         """
-        target = self._data._realize_cached_data()
+        target = self._materialize_data()
         if isinstance(self, nn.Parameter):
             target = nn.Parameter(target, requires_grad=self.requires_grad)
-        self._clear_all()
         return target
 
-    def traceable(self) -> MetaTensor:
-        """Generate ``MetaTensor`` from ``LazyTensor``. (Mostly for tracing)
-
-        Returns:
-            MetaTensor: The generated ``MetaTensor``.
+    def clean(self) -> None:
+        """Clean all stored operations, meta data and materialized data, which prevents memory leaking. This should be called after all tensors are materialized.
         """
-        if isinstance(self, nn.Parameter):
-            return nn.Parameter(self._meta_data, requires_grad=self.requires_grad)
-        else:
-            return self._meta_data
-
-    def distribute(self) -> torch.Tensor:
-        """Distribute the ``LazyTensor`` according to the ``ShardingSpec``.
-
-        Returns:
-            torch.Tensor: The sharded tensor.
-        """
-        if self._spec is None:
-            raise RuntimeError('ShardingSpec is not set for\n{self}')
-        spec, device_mesh = self._spec, self._spec.device_mesh
-        target = self.materialize()
-
-        # TODO(some man): better not be coupled with auto-parallel
-        target.data = scm.apply_for_autoparallel_runtime(target.data, ShardingSpec(device_mesh, target.shape, {}),
-                                                         spec).detach().clone()
-        return target
-
-    def _realize_cached_data(self) -> torch.Tensor:
-        # self._cached_data should be generated after the first call of this function
-        if self._cached_data is None:
-            if self._factory_method is not None:
-                # apply factory method
-                func, args, kwargs = self._factory_method
-                # apply cached sequence
-                self._cached_data = self._apply_cache_buffer(func(*args, **kwargs))
-            else:
-                # apply cached sequence only
-                self._cached_data = self._apply_cache_buffer()
-        return self._cached_data
-
-    def _apply_cache_buffer(self, target=None) -> torch.Tensor:
-        # dump all cached sequence
-        # super-dainiu: support methods for single Tensor only
+        self._factory_method = None
+        self._op_buffer = None
+        self._materialized_data = None
+        self._meta_data = None
+
+    @staticmethod
+    def _replace_with_materialized(x):
+        if isinstance(x, LazyTensor):
+            return x._materialize_data()
+        return x
+
+    def _materialize_data(self) -> torch.Tensor:
+        # self._materialized_data should be generated after the first call of this function
+        if self._materialized_data is None:
+            # apply factory method
+            func, args, kwargs = self._factory_method
+
+            # apply cached sequence
+            self._pre_op_fn()
+            try:
+                init_val = func(*tree_map(self._replace_with_materialized, args),
+                                **tree_map(self._replace_with_materialized, kwargs))
+            except TypeError as e:
+                print(f'init fn: {func.__name__}')
+                raise e
+            self._materialized_data = self._rerun_ops(init_val)
+        return self._materialized_data
+
+    def _rerun_ops(self, target=None) -> torch.Tensor:
+        """Do lazy execution by rerunning all (stored) related operations.
+
+        Args:
+            target (torc.Tensor, optional): Intial value of the target tensor (self). Defaults to None.
+        """
+
         def replace(x):
             if x is self:
                 return target
             elif isinstance(x, LazyTensor):
-                return x._realize_cached_data()
+                return x._materialize_data()
             return x
 
         packed = None
 
-        for (func, args, kwargs) in self._cached_buffer:
+        for (func, args, kwargs) in self._op_buffer:
             if func == torch.Tensor.requires_grad_:
                 packed = func, args, kwargs    # requires grad should be set at last
             else:
+                self._pre_op_fn()
                 o = func(*tree_map(replace, args), **tree_map(replace, kwargs))
                 target = o if isinstance(o, torch.Tensor) else target    # if func returns non-Tensor, discard the value
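The hunk above contains the core replay logic: every intercepted operation is appended to _op_buffer together with its arguments, and _rerun_ops replays the buffered sequence on the freshly created tensor, deferring requires_grad_ to the end. A rough sketch of the observable behaviour, with the same assumed import path as in the earlier sketch:

    # Rough illustration of buffer-and-replay; the import path is an assumption.
    import torch
    from colossalai.utils.model.experimental import LazyTensor

    x = LazyTensor(torch.zeros, 2, 3)    # only meta data is created here
    x.add_(1.0)                          # in-place op: recorded in x's op buffer, no real data touched
    y = x * 2                            # out-of-place op: a new LazyTensor wrapping the result
    t = y.materialize()                  # now torch.zeros -> add_ -> mul actually run
    print(t)                             # every element is 2.0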
@@ -194,24 +176,23 @@ class LazyTensor(torch.Tensor):
 
         return target
 
-    # clear all means:
-    #   1. clear factory method
-    #   2. clear cached sequence
-    #   3. clear cached data
-    def _clear_all(self):
-        self._cached_data = None
-        self._cached_buffer = None
-        self._data = None
-        gc.collect()    # avoid memory leak
-
     # cache everything with __torch_function__
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
         if kwargs is None:
             kwargs = {}
-        target = None
+        if func.__name__ in _EARLY_MATERIALIZED_OPS:
+            # These OPs cannot be lazy and related tensors should be early materialized
+            tree_map(cls._replace_with_materialized, args)
+            tree_map(cls._replace_with_materialized, kwargs)
+        is_inplace: bool = (func.__name__.endswith('_') and not (func.__name__.endswith('__'))
+                            or func.__name__ == "__setitem__")
 
         if isinstance(func, torch._C.ScriptMethod):
+            # FIXME(ver217): torch script functions are not verified
+
+            target = None
 
             def unwrap(x):
                 if isinstance(x, LazyTensor):
@@ -219,79 +200,83 @@ class LazyTensor(torch.Tensor):
                 return x
 
             target: LazyTensor = args[0].clone()
-            target._cached_buffer.append((func, args, kwargs))
+            target._op_buffer.append((func, args, kwargs))
             target._meta_data = getattr(target._meta_data, func.name)(*tree_map(unwrap, args[1:]),
                                                                       **tree_map(unwrap, kwargs))
-        else:
-
-            def unwrap(x):
-                nonlocal target
-                if isinstance(x, LazyTensor):
-                    target = x if (func.__name__.endswith('_') and not (func.__name__.endswith('__'))
-                                   or func.__name__ == "__setitem__") else x.clone()
-                    target._cached_buffer.append((func, args, kwargs))
-                    return x._meta_data
-                return x
-
-            args = tree_map(unwrap, args)
-            kwargs = tree_map(unwrap, kwargs)
-            o = func(*args, **kwargs)
-
-        if isinstance(o, MetaTensor):
-            target._meta_data = o
             return target
         else:
-            return o
+
+            meta_to_lazy = {}
+
+            def unwrap(x):
+                if isinstance(x, LazyTensor):
+                    if x._materialized_data is not None:
+                        # for early materialized tensor, use its materialized data directly
+                        return x._materialized_data
+                    t = x if is_inplace else x.clone()
+                    t._op_buffer.append((func, args, kwargs))
+                    meta = x._meta_data.data
+                    meta_to_lazy[meta] = t
+                    return meta
+                return x
+
+            def wrap(y, i=None):
+                if isinstance(y, MetaTensor):
+                    if y in meta_to_lazy:
+                        # inplace op, just return origin lazy tensor
+                        return meta_to_lazy[y]
+                    else:
+                        # out of place op, create new lazy tensor
+                        fn = lambda *a, **kw: func(*a, **kw) if i is None else func(*a, **kw)[i]
+                        lazy_y = LazyTensor(fn, *args, meta_data=y, **kwargs)
+                        return lazy_y
+                elif type(y) is Tensor:
+                    # for early materialized tensor
+                    with torch._C.DisableTorchFunction():
+                        meta = MetaTensor(y.new_empty(y.shape, dtype=y.dtype, device='meta'), fake_device=y.device)
+                    lazy_y = LazyTensor(lambda: None, meta_data=meta)
+                    lazy_y._materialized_data = y
+                    return lazy_y
+                return y
+
+            o = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
+            if isinstance(o, (tuple, list)):
+                return type(o)(wrap(y, i=i) for i, y in enumerate(o))
+            return wrap(o)
 
     @classmethod
     def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
         pass    # skip
 
     def clone(self) -> "LazyTensor":
-        """Create a new ``LazyTensor`` with same cached sequence and factory method.
-
-        Returns:
-            LazyTensor: the new ``LazyTensor``
-        """
-        target = LazyTensor(orig_empty, 0, dtype=self._meta_data.dtype, device=self._meta_data.device)
-        target._factory_method = None
-        target._cached_buffer = list()
-        target._meta_data = self._meta_data.clone()
-        target._cached_data = self._cached_data.clone() if self._cached_data is not None else None
-        target._spec = copy.deepcopy(self._spec)
+
+        def factory_fn():
+            return self.materialize().clone()
+
+        target = LazyTensor(factory_fn, meta_data=self._meta_data)
+
         return target
 
-    def detach(self) -> "LazyTensor":
-        target = self.clone()
-        target._cached_buffer.append((torch.Tensor.detach_, (self,), {}))
-        return target
+    def detach(self) -> Tensor:
+        return self
 
     @property
-    def spec(self) -> ShardingSpec:
-        return self._spec
-
-    @spec.setter
-    def spec(self, other: ShardingSpec):
-        self._spec = other
-
-    @property
-    def data(self) -> "LazyTensor":
-        return self._data.detach()
+    def data(self):
+        return self
 
     @data.setter
-    def data(self, other: "LazyTensor") -> "LazyTensor":
-        """This avoid the following infinite recursion, which is very common in ``nn.Module`` initialization.
-
-        Usage:
-            >>> a = LazyTensor(torch.empty, 0, dtype=torch.float32, device='cpu')
-            >>> b = a.cuda()
-            >>> a.data = b
-        """
-        self._data = other
+    def data(self, other: 'LazyTensor'):
+        raise NotImplementedError
+
+    def tolist(self) -> list:
+        t = self.materialize()
+        return t.tolist()
+
+    def __hash__(self):
+        return id(self)
 
 
-class LazyInitContext():
+class LazyInitContext:
     """Context manager for lazy initialization. Enables initializing the model without allocating real memory.
 
     Usage:
@@ -319,16 +304,21 @@
         1. Quantization strategies can be applied before allocating real memory.
         2. Lazy initialization seems slower than normal initialization.
     """
+    _replaced: bool = False
 
-    def __init__(self):
+    def __init__(self, tensor_cls: Union[_MyTensor, LazyTensor] = LazyTensor):
         self.overrides = {}
+        self.tensor_cls = tensor_cls
 
     def __enter__(self):
+        if LazyInitContext._replaced:
+            raise RuntimeError(f'LazyInitContext is not reentrant')
+        LazyInitContext._replaced = True
 
         def wrap_factory_method(target):
             # factory functions (eg. torch.empty())
             def wrapper(*args, **kwargs):
-                return LazyTensor(target, *args, **kwargs)
+                return self.tensor_cls(target, *args, **kwargs)
 
             return wrapper, target
@@ -336,7 +326,7 @@
             # factory_like functions (eg. torch.empty_like())
             def wrapper(*args, **kwargs):
                 orig_t = args[0]
-                return LazyTensor(orig_target, *args[1:], device=orig_t.device, dtype=orig_t.dtype, **kwargs)
+                return self.tensor_cls(orig_target, *args[1:], device=orig_t.device, dtype=orig_t.dtype, **kwargs)
 
             return wrapper, target
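The two wrappers above cover plain factory functions and their *_like variants (which forward the reference tensor's device and dtype); the next hunk installs them with setattr(torch, name, wrapper) and restores the originals in __exit__. Reduced to its essentials, the patch/restore pattern is roughly the following standalone sketch, not the exact code from the diff:

    # Simplified patch/restore pattern; the real code iterates _TorchFactoryMethod
    # and uses self.tensor_cls instead of a plain argument.
    import torch

    overrides = {}

    def patch(name, tensor_cls):
        orig = getattr(torch, name)

        def wrapper(*args, **kwargs):
            return tensor_cls(orig, *args, **kwargs)    # defer the real allocation

        overrides[name] = (wrapper, orig)
        setattr(torch, name, wrapper)

    def restore():
        for name, (wrapper, orig) in overrides.items():
            setattr(torch, name, orig)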
@@ -356,85 +346,51 @@
             setattr(torch, name, wrapper)
 
     def __exit__(self, exc_type, exc_val, exc_tb):
+        LazyInitContext._replaced = False
         for name, (wrapper, orig) in self.overrides.items():
             setattr(torch, name, orig)
 
     @staticmethod
-    def materialize(module: torch.nn.Module):
+    def materialize(module: torch.nn.Module, verbose: bool = False):
         """Initialize all ``nn.Parameter`` from ``LazyTensor``.
 
         Args:
             module (torch.nn.Module): Target ``nn.Module``
+            verbose (bool): Whether to print lazy initialization rate. Defaults to False.
         """
+        if verbose:
+            param_cnt = 0
+            param_lazy_cnt = 0
+            buf_cnt = 0
+            buf_lazy_cnt = 0
 
         @torch.no_grad()
         def init_recursively(module: nn.Module):
+            nonlocal param_cnt, param_lazy_cnt, buf_cnt, buf_lazy_cnt
             # recursively initialize the module
             for mod in module.children():
                 init_recursively(mod)
 
             # initialize tensors directly attached to the current module
             for name, param in module.named_parameters(recurse=False):
+                if verbose:
+                    param_cnt += 1
+                    if param._materialized_data is None:
+                        param_lazy_cnt += 1
                 setattr(module, name, param.materialize())
+                param.clean()
 
             for name, buf in module.named_buffers(recurse=False):
+                if verbose:
+                    buf_cnt += 1
+                    if buf._materialized_data is None:
+                        buf_lazy_cnt += 1
                 setattr(module, name, buf.materialize())
+                buf.clean()
 
         init_recursively(module)
+
+        if verbose:
+            print(f'Param lazy rate: {param_lazy_cnt}/{param_cnt}')
+            print(f'Buffer lazy rate: {buf_lazy_cnt}/{buf_cnt}')
         return module
 
-    @staticmethod
-    def distribute(module: torch.nn.Module):
-        """Initialize and shard all ``nn.Parameter`` from ``LazyTensor``.
-
-        Args:
-            module (torch.nn.Module): Sharded target ``nn.Module``
-        """
-
-        @torch.no_grad()
-        def init_recursively(module: nn.Module):
-            # recursively initialize the module
-            for mod in module.children():
-                init_recursively(mod)
-
-            # initialize tensors directly attached to the current module
-            for name, param in module.named_parameters(recurse=False):
-                setattr(module, name, param.distribute())
-
-            for name, buf in module.named_buffers(recurse=False):
-                setattr(module, name, buf.distribute())
-
-        init_recursively(module)
-        return module
-
-    @staticmethod
-    @contextlib.contextmanager
-    def traceable(module: torch.nn.Module):
-        """Initialize all ``nn.Parameters`` as ``MetaTensor``. This enables ``ColoTracer`` with control flow.
-
-        Args:
-            module (torch.nn.Module): Traceable ``nn.Module`` with ``MetaTensor`` as parameters.
-        """
-        orig_val = dict()
-
-        def init_recursively(module: nn.Module):
-            # recursively initialize the module
-            for mod in module.children():
-                init_recursively(mod)
-
-            # initialize tensors directly attached to the current module
-            for name, param in module.named_parameters(recurse=False):
-                setattr(module, name, param.traceable())
-                orig_val[(module, name)] = param
-
-            for name, buf in module.named_buffers(recurse=False):
-                setattr(module, name, buf.traceable())
-                orig_val[(module, name)] = buf
-
-        init_recursively(module)
-
-        yield
-
-        # restore original values
-        for (module, name), val in orig_val.items():
-            setattr(module, name, val)
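_MyTensor is described above as existing only for correctness verification: it executes every op eagerly while exposing the same _pre_op_fn hook as LazyTensor, so the same initialization code can be run down both paths and compared. A plausible sketch of such a check follows; the harness and the import path are assumptions, not part of this diff.

    # Verification sketch: the eager _MyTensor path vs. the lazy path.
    import torch.nn as nn
    from colossalai.utils.model.experimental import LazyInitContext, _MyTensor

    def build():
        return nn.Linear(16, 16)

    with LazyInitContext(tensor_cls=_MyTensor):
        eager_model = build()                    # ops run immediately on _MyTensor

    with LazyInitContext():                      # default tensor_cls=LazyTensor
        lazy_model = build()
    lazy_model = LazyInitContext.materialize(lazy_model, verbose=True)
    # verbose=True prints the lazy rate, e.g. 'Param lazy rate: 2/2' and 'Buffer lazy rate: 0/0'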