diff --git a/colossalai/utils/model/experimental.py b/colossalai/utils/model/experimental.py
index 8291227b7..b8eb742f8 100644
--- a/colossalai/utils/model/experimental.py
+++ b/colossalai/utils/model/experimental.py
@@ -1,17 +1,11 @@
-import contextlib
-import copy
-import gc
-import pprint
-from typing import Callable, List, Optional, Union
+from typing import Callable, Optional, Union
 
 import torch
 import torch.nn as nn
+from torch import Tensor
 from torch.utils._pytree import tree_map
 
-from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx.profiler import MetaTensor
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager
-from colossalai.tensor.sharding_spec import ShardingSpec
 
 # reference: https://pytorch.org/cppdocs/notes/tensor_creation.html
 _TorchFactoryMethod = [
@@ -30,9 +24,23 @@ _TorchFactoryMethod = [
     "tensor",
 ]
 
-orig_empty = torch.empty    # avoid override
+_EARLY_MATERIALIZED_OPS = ['__getitem__', 'split']
 
-scm = ShapeConsistencyManager()
+
+class _MyTensor(Tensor):
+    """This class is only for correctness verification.
+    """
+    _pre_op_fn: Callable[['LazyTensor'], None] = lambda *args: None
+
+    def __new__(cls, func, *args, dtype=None, device=None, **kwargs) -> '_MyTensor':
+        cls._pre_op_fn()
+        data = func(*args, dtype=dtype, device=device, **kwargs)
+        return Tensor._make_subclass(cls, data, require_grad=data.requires_grad)
+
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        cls._pre_op_fn()
+        return super().__torch_function__(func, types, args, kwargs)
 
 
 class LazyTensor(torch.Tensor):
@@ -50,140 +58,114 @@ class LazyTensor(torch.Tensor):
         tensor([[0., 1., 1.],
                 [1., 1., 1.]], device='cuda:0', dtype=torch.float16)
 
-        2. Generate ``MetaTensor`` from ``LazyTensor``
-        >>> x = LazyTensor(torch.zeros, 2, 3)
-        >>> x.reshape(3, 2)
-        >>> x = x.traceable()    # generate ``MetaTensor``
-        >>> print(x)
-        MetaTensor(..., size=(3, 2), device=cpu, dtype=torch.float32)
-
-        3. Use ``LazyTensor`` to generate sharded ``nn.Parameter``.
-        >>> x = LazyTensor(torch.zeros, 2, 3)
-        >>> x.spec = ...    # some ``ShardingSpec``
-        >>> x.distribute()    # distribute the tensor according to the ``ShardingSpec``
-
     Warnings:
         1. Cases that ``LazyTensor`` can't deal with.
        >>> x = LazyTensor(torch.ones, 2, 3)
        >>> x[0, 0] = -x[0, 0]    # this will cause infinite recursion
+        >>> y = x.clone()
+        >>> x.add_(1)    # modifying the original tensor after cloning leads to incorrect materialization
+        >>> z = x.tolist()
+        >>> x.zero_()    # modifying the original tensor after calling tolist() is not allowed
+        >>> x.data = torch.rand(2, 3)    # directly setting the data of a lazy tensor is not allowed
+
+        2. Cases where ``LazyTensor`` becomes eager (early materialization).
+        >>> b = a[:, 2:]    # taking a slice of a lazy tensor triggers early materialization
+        >>> chunks = a.split(3)    # this also triggers early materialization
 
-        2. ``LazyTensor.materialize()`` can't be called multiple times.
-        >>> x = LazyTensor(torch.ones, 2, 3)
-        >>> x.materialize()
-        >>> x.materialize()    # this is disallowed
     """
     _repr = True
     _meta_data: Optional[MetaTensor] = None    # shape, dtype, device
-    _cached_data: Optional[torch.Tensor] = None    # materialized data
+    _pre_op_fn: Callable[['LazyTensor'], None] = lambda *args: None
 
     @staticmethod
-    def __new__(cls, func, *args, dtype=None, device=None, **kwargs):
-        elem = func(*args, dtype=dtype, device='meta', **kwargs)
+    def __new__(cls, func, *args, meta_data=None, **kwargs):
+        if meta_data is None:
+            device = kwargs.get('device', 'cpu')
+            elem = func(*args, **{**kwargs, 'device': 'meta'})
+            meta_data = MetaTensor(elem, fake_device=device)
+        elem = meta_data._tensor
         r = torch.Tensor._make_wrapper_subclass(cls,
                                                 elem.size(),
                                                 strides=elem.stride(),
                                                 storage_offset=elem.storage_offset(),
                                                 dtype=elem.dtype,
                                                 layout=elem.layout,
-                                                device=device if device is not None else torch.device('cpu'),
+                                                device=elem.device,
                                                 requires_grad=elem.requires_grad)
-        r._meta_data = MetaTensor(elem, fake_device=device)
+        r._meta_data = meta_data
         return r
 
-    def __init__(self, func, *args, dtype=None, device=None, **kwargs):
-        self._factory_method = (func, args, {'dtype': dtype, 'device': device, **kwargs})    # (func, args, kwargs)
-        self._cached_buffer = list()    # (func, args, kwargs)
-        self._spec = None
-        self._data = self
-
-    def __repr__(self):
-        if self._repr:
-            # avoid recursive representation
-            self.__class__._repr = False
-            s = f'LazyTensor(..., size={tuple(self._meta_data.shape)}, device={self._meta_data.device}, dtype={self._meta_data.dtype})\n'\
-                f'factory method: {self._factory_method}\n'\
-                f'cached: {pprint.pformat(self._cached_buffer) if self._cached_data is None else self._cached_data}\n'\
-                f'spec: {self._spec}'
-            self.__class__._repr = True
-            return s
-        else:
-            return 'LazyTensor(...)'
+    def __init__(self, func, *args, meta_data=None, **kwargs):
+        self._factory_method = (func, args, kwargs)    # (func, args, kwargs)
+        self._op_buffer = []    # (func, args, kwargs, replace)
+        self._materialized_data: Optional[torch.Tensor] = None    # materialized data
 
     def materialize(self) -> torch.Tensor:
         """Materialize the ``LazyTensor`` to ``torch.Tensor``.
 
-        Warnings:
-            Calling ``self.materialize()`` will clear all cached sequence and factory method,
-            because we don't allow materialize the same ``LazyTensor`` twice.
-            This is mentioned in the paper: https://arxiv.org/pdf/2102.13267.pdf (Part 4.3).
-
         Returns:
             torch.Tensor: The materialized tensor.
         """
-        target = self._data._realize_cached_data()
+        target = self._materialize_data()
         if isinstance(self, nn.Parameter):
             target = nn.Parameter(target, requires_grad=self.requires_grad)
-        self._clear_all()
         return target
 
-    def traceable(self) -> MetaTensor:
-        """Generate ``MetaTensor`` from ``LazyTensor``. (Mostly for tracing)
-
-        Returns:
-            MetaTensor: The generated ``MetaTensor``.
+    def clean(self) -> None:
+        """Clean all stored operations, meta data and materialized data, which prevents memory leaks. This should be called after all tensors are materialized.
         """
-        if isinstance(self, nn.Parameter):
-            return nn.Parameter(self._meta_data, requires_grad=self.requires_grad)
-        else:
-            return self._meta_data
+        self._factory_method = None
+        self._op_buffer = None
+        self._materialized_data = None
+        self._meta_data = None
 
-    def distribute(self) -> torch.Tensor:
-        """Distribute the ``LazyTensor`` according to the ``ShardingSpec``.
+    @staticmethod
+    def _replace_with_materialized(x):
+        if isinstance(x, LazyTensor):
+            return x._materialize_data()
+        return x
 
-        Returns:
-            torch.Tensor: The sharded tensor.
+    def _materialize_data(self) -> torch.Tensor:
+        # self._materialized_data should be generated after the first call of this function
+        if self._materialized_data is None:
+            # apply factory method
+            func, args, kwargs = self._factory_method
+
+            # apply cached sequence
+            self._pre_op_fn()
+
+            try:
+                init_val = func(*tree_map(self._replace_with_materialized, args),
+                                **tree_map(self._replace_with_materialized, kwargs))
+            except TypeError as e:
+                print(f'init fn: {func.__name__}')
+                raise e
+
+            self._materialized_data = self._rerun_ops(init_val)
+        return self._materialized_data
+
+    def _rerun_ops(self, target=None) -> torch.Tensor:
+        """Do lazy execution by rerunning all stored related operations.
+
+        Args:
+            target (torch.Tensor, optional): Initial value of the target tensor (self). Defaults to None.
         """
-        if self._spec is None:
-            raise RuntimeError('ShardingSpec is not set for\n{self}')
-        spec, device_mesh = self._spec, self._spec.device_mesh
-        target = self.materialize()
-        # TODO(some man): better not be coupled with auto-parallel
-        target.data = scm.apply_for_autoparallel_runtime(target.data, ShardingSpec(device_mesh, target.shape, {}),
-                                                         spec).detach().clone()
-        return target
-
-    def _realize_cached_data(self) -> torch.Tensor:
-        # self._cached_data should be generated after the first call of this function
-        if self._cached_data is None:
-            if self._factory_method is not None:
-                # apply factory method
-                func, args, kwargs = self._factory_method
-
-                # apply cached sequence
-                self._cached_data = self._apply_cache_buffer(func(*args, **kwargs))
-            else:
-                # apply cached sequence only
-                self._cached_data = self._apply_cache_buffer()
-        return self._cached_data
-
-    def _apply_cache_buffer(self, target=None) -> torch.Tensor:
-        # dump all cached sequence
-        # super-dainiu: support methods for single Tensor only
 
         def replace(x):
             if x is self:
                 return target
             elif isinstance(x, LazyTensor):
-                return x._realize_cached_data()
+                return x._materialize_data()
             return x
 
         packed = None
 
-        for (func, args, kwargs) in self._cached_buffer:
+        for (func, args, kwargs) in self._op_buffer:
             if func == torch.Tensor.requires_grad_:
                 packed = func, args, kwargs    # requires grad should be set at last
             else:
+                self._pre_op_fn()
                 o = func(*tree_map(replace, args), **tree_map(replace, kwargs))
                 target = o if isinstance(o, torch.Tensor) else target    # if func returns non-Tensor, discard the value
@@ -194,24 +176,23 @@ class LazyTensor(torch.Tensor):
 
         return target
 
-    # clear all means:
-    # 1. clear factory method
-    # 2. clear cached sequence
-    # 3. clear cached data
-    def _clear_all(self):
-        self._cached_data = None
-        self._cached_buffer = None
-        self._data = None
-        gc.collect()    # avoid memory leak
-
     # cache everything with __torch_function__
+    @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
         if kwargs is None:
             kwargs = {}
-        target = None
+        if func.__name__ in _EARLY_MATERIALIZED_OPS:
+            # These OPs cannot be lazy and the related tensors should be materialized early
+            tree_map(cls._replace_with_materialized, args)
+            tree_map(cls._replace_with_materialized, kwargs)
+        is_inplace: bool = (func.__name__.endswith('_') and not (func.__name__.endswith('__'))
+                            or func.__name__ == "__setitem__")
 
         if isinstance(func, torch._C.ScriptMethod):
+            # FIXME(ver217): torch script functions are not verified
+
+            target = None
 
             def unwrap(x):
                 if isinstance(x, LazyTensor):
@@ -219,79 +200,83 @@ class LazyTensor(torch.Tensor):
                 return x
 
             target: LazyTensor = args[0].clone()
-            target._cached_buffer.append((func, args, kwargs))
+            target._op_buffer.append((func, args, kwargs))
             target._meta_data = getattr(target._meta_data, func.name)(*tree_map(unwrap, args[1:]),
                                                                       **tree_map(unwrap, kwargs))
-
-        else:
-
-            def unwrap(x):
-                nonlocal target
-                if isinstance(x, LazyTensor):
-                    target = x if (func.__name__.endswith('_') and not (func.__name__.endswith('__'))
-                                   or func.__name__ == "__setitem__") else x.clone()
-                    target._cached_buffer.append((func, args, kwargs))
-                    return x._meta_data
-                return x
-
-            args = tree_map(unwrap, args)
-            kwargs = tree_map(unwrap, kwargs)
-            o = func(*args, **kwargs)
-
-        if isinstance(o, MetaTensor):
-            target._meta_data = o
             return target
         else:
-            return o
+
+            meta_to_lazy = {}
+
+            def unwrap(x):
+                if isinstance(x, LazyTensor):
+                    if x._materialized_data is not None:
+                        # for an early materialized tensor, use its materialized data directly
+                        return x._materialized_data
+                    t = x if is_inplace else x.clone()
+                    t._op_buffer.append((func, args, kwargs))
+                    meta = x._meta_data.data
+                    meta_to_lazy[meta] = t
+                    return meta
+                return x
+
+            def wrap(y, i=None):
+                if isinstance(y, MetaTensor):
+                    if y in meta_to_lazy:
+                        # in-place op, just return the original lazy tensor
+                        return meta_to_lazy[y]
+                    else:
+                        # out-of-place op, create a new lazy tensor
+                        fn = lambda *a, **kw: func(*a, **kw) if i is None else func(*a, **kw)[i]
+                        lazy_y = LazyTensor(fn, *args, meta_data=y, **kwargs)
+                        return lazy_y
+                elif type(y) is Tensor:
+                    # for an early materialized tensor
+                    with torch._C.DisableTorchFunction():
+                        meta = MetaTensor(y.new_empty(y.shape, dtype=y.dtype, device='meta'), fake_device=y.device)
+                    lazy_y = LazyTensor(lambda: None, meta_data=meta)
+                    lazy_y._materialized_data = y
+                    return lazy_y
+                return y
+
+            o = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
+            if isinstance(o, (tuple, list)):
+                return type(o)(wrap(y, i=i) for i, y in enumerate(o))
+            return wrap(o)
 
     @classmethod
     def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
         pass    # skip
 
     def clone(self) -> "LazyTensor":
-        """Create a new ``LazyTensor`` with same cached sequence and factory method.
-
-        Returns:
-            LazyTensor: the new ``LazyTensor``
-        """
-        target = LazyTensor(orig_empty, 0, dtype=self._meta_data.dtype, device=self._meta_data.device)
-        target._factory_method = None
-        target._cached_buffer = list()
-        target._meta_data = self._meta_data.clone()
-        target._cached_data = self._cached_data.clone() if self._cached_data is not None else None
-        target._spec = copy.deepcopy(self._spec)
+        def factory_fn():
+            return self.materialize().clone()
+
+        target = LazyTensor(factory_fn, meta_data=self._meta_data)
+
         return target
 
-    def detach(self) -> "LazyTensor":
-        target = self.clone()
-        target._cached_buffer.append((torch.Tensor.detach_, (self,), {}))
-        return target
+    def detach(self) -> Tensor:
+        return self
 
     @property
-    def spec(self) -> ShardingSpec:
-        return self._spec
-
-    @spec.setter
-    def spec(self, other: ShardingSpec):
-        self._spec = other
-
-    @property
-    def data(self) -> "LazyTensor":
-        return self._data.detach()
+    def data(self):
+        return self
 
     @data.setter
-    def data(self, other: "LazyTensor") -> "LazyTensor":
-        """This avoid the following infinite recursion, which is very common in ``nn.Module`` initialization.
+    def data(self, other: 'LazyTensor'):
+        raise NotImplementedError
 
-        Usage:
-            >>> a = LazyTensor(torch.empty, 0, dtype=torch.float32, device='cpu')
-            >>> b = a.cuda()
-            >>> a.data = b
-        """
-        self._data = other
+    def tolist(self) -> list:
+        t = self.materialize()
+        return t.tolist()
+
+    def __hash__(self):
+        return id(self)
 
 
-class LazyInitContext():
+class LazyInitContext:
     """Context manager for lazy initialization. Enables initializing the model without allocating real memory.
 
     Usage:
@@ -319,16 +304,21 @@ class LazyInitContext():
     1. Quantization strategies can be applied before allocating real memory.
     2. Lazy initialization seems slower than normal initialization.
     """
+    _replaced: bool = False
 
-    def __init__(self):
+    def __init__(self, tensor_cls: Union[_MyTensor, LazyTensor] = LazyTensor):
         self.overrides = {}
+        self.tensor_cls = tensor_cls
 
     def __enter__(self):
+        if LazyInitContext._replaced:
+            raise RuntimeError('LazyInitContext is not reentrant')
+        LazyInitContext._replaced = True
 
         def wrap_factory_method(target):
             # factory functions (eg. torch.empty())
             def wrapper(*args, **kwargs):
-                return LazyTensor(target, *args, **kwargs)
+                return self.tensor_cls(target, *args, **kwargs)
 
             return wrapper, target
 
@@ -336,7 +326,7 @@ class LazyInitContext():
             # factory_like functions (eg. torch.empty_like())
             def wrapper(*args, **kwargs):
                 orig_t = args[0]
-                return LazyTensor(orig_target, *args[1:], device=orig_t.device, dtype=orig_t.dtype, **kwargs)
+                return self.tensor_cls(orig_target, *args[1:], device=orig_t.device, dtype=orig_t.dtype, **kwargs)
 
             return wrapper, target
 
@@ -356,85 +346,51 @@ class LazyInitContext():
             setattr(torch, name, wrapper)
 
     def __exit__(self, exc_type, exc_val, exc_tb):
+        LazyInitContext._replaced = False
         for name, (wrapper, orig) in self.overrides.items():
             setattr(torch, name, orig)
 
     @staticmethod
-    def materialize(module: torch.nn.Module):
+    def materialize(module: torch.nn.Module, verbose: bool = False):
         """Initialize all ``nn.Parameter`` from ``LazyTensor``.
 
         Args:
             module (torch.nn.Module): Target ``nn.Module``
+            verbose (bool): Whether to print lazy initialization rate. Defaults to False.
""" + if verbose: + param_cnt = 0 + param_lazy_cnt = 0 + buf_cnt = 0 + buf_lazy_cnt = 0 @torch.no_grad() def init_recursively(module: nn.Module): + nonlocal param_cnt, param_lazy_cnt, buf_cnt, buf_lazy_cnt # recursively initialize the module for mod in module.children(): init_recursively(mod) # initialize tensors directly attached to the current module for name, param in module.named_parameters(recurse=False): + if verbose: + param_cnt += 1 + if param._materialized_data is None: + param_lazy_cnt += 1 setattr(module, name, param.materialize()) + param.clean() for name, buf in module.named_buffers(recurse=False): + if verbose: + buf_cnt += 1 + if buf._materialized_data is None: + buf_lazy_cnt += 1 setattr(module, name, buf.materialize()) + buf.clean() init_recursively(module) + + if verbose: + print(f'Param lazy rate: {param_lazy_cnt}/{param_cnt}') + print(f'Buffer lazy rate: {buf_lazy_cnt}/{buf_cnt}') return module - - @staticmethod - def distribute(module: torch.nn.Module): - """Initialize and shard all ``nn.Parameter`` from ``LazyTensor``. - - Args: - module (torch.nn.Module): Sharded target ``nn.Module`` - """ - - @torch.no_grad() - def init_recursively(module: nn.Module): - # recursively initialize the module - for mod in module.children(): - init_recursively(mod) - - # initialize tensors directly attached to the current module - for name, param in module.named_parameters(recurse=False): - setattr(module, name, param.distribute()) - - for name, buf in module.named_buffers(recurse=False): - setattr(module, name, buf.distribute()) - - init_recursively(module) - return module - - @staticmethod - @contextlib.contextmanager - def traceable(module: torch.nn.Module): - """Initialize all ``nn.Parameters`` as ``MetaTensor``. This enables ``ColoTracer`` with control flow. - - Args: - module (torch.nn.Module): Traceable ``nn.Module`` with ``MetaTensor`` as parameters. - """ - orig_val = dict() - - def init_recursively(module: nn.Module): - # recursively initialize the module - for mod in module.children(): - init_recursively(mod) - - # initialize tensors directly attached to the current module - for name, param in module.named_parameters(recurse=False): - setattr(module, name, param.traceable()) - orig_val[(module, name)] = param - - for name, buf in module.named_buffers(recurse=False): - setattr(module, name, buf.traceable()) - orig_val[(module, name)] = buf - - init_recursively(module) - - yield - - # restore original values - for (module, name), val in orig_val.items(): - setattr(module, name, val)