from types import MethodType from typing import Callable, Optional, Union import torch import torch.nn as nn from packaging import version from torch import Tensor from torch.nn import Parameter from torch.utils._pytree import tree_map from colossalai.logging import get_dist_logger from .construction import ConstructorManager from .pretrained import PretrainedManager import colossalai._analyzer._subclasses._meta_registration # noqa # reference: https://pytorch.org/cppdocs/notes/tensor_creation.html _NORMAL_FACTORY = [ "arange", "full", "empty", "linspace", "logspace", "ones", "rand", "randn", "randint", "randperm", "zeros", "tensor", ] # factory function that does not support meta tensor backend _NO_META_FACTORY = [ "eye", ] _EARLY_MATERIALIZED_OPS = ["__getitem__", "split"] # If your intent is to change the metadata of a Tensor (such as sizes / strides / storage / storage_offset) # without autograd tracking the change, remove the .data / .detach() call and wrap the change in a `with torch.no_grad():` block. # These ops cannot be unwrapped using .data _CHANGE_META_OPS = ["_cudnn_rnn_flatten_weight", "requires_grad_", "__get__", "__set__", "numel", "size", "dim"] # These ops is not related to tensor value and should not be rerun _NO_RERUN_OPS = ["__get__", "numel", "size", "dim"] _LEGACY_TENSOR_CONSTRUCTOR = { "FloatTensor": torch.float, "DoubleTensor": torch.double, "HalfTensor": torch.half, "BFloat16Tensor": torch.bfloat16, "ByteTensor": torch.uint8, "CharTensor": torch.int8, "ShortTensor": torch.short, "IntTensor": torch.int, "LongTensor": torch.long, "BoolTensor": torch.bool, } # These ops have at least one lazy tensor argument and maybe a scalar argument # scalar value should be converted to meta tensor # this is a hack for torch 2.0 _EXPAND_SCALAR_OPS = [ "where", "clamp", "clamp_min", "clamp_max", "clamp_", "clamp_min_", "clamp_max_", ] _old_tensor_factory = torch.tensor _EMPTY_DATA = torch.empty(0) class _MyTensor(Tensor): """This class is only for correctness verification.""" _pre_op_fn: Callable[["LazyTensor"], None] = lambda *args: None default_device: Optional[torch.device] = None def __new__(cls, func, *args, concrete_data=None, **kwargs) -> "_MyTensor": cls._pre_op_fn() if concrete_data is not None: # uniform api as LazyTensor data = concrete_data else: kwargs["device"] = cls.default_device data = func(*args, **kwargs) return Tensor._make_subclass(cls, data, require_grad=data.requires_grad) @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): cls._pre_op_fn() return super().__torch_function__(func, types, args, kwargs) def _data_tolist(tensor: torch.Tensor) -> list: """tolist() method is not allowed for a subclass of tensor. Tensor.data returns a Tensor.""" return tensor.data.tolist() def _convert_cls(tensor: "LazyTensor", target: torch.Tensor) -> torch.Tensor: """Convert a lazy tensor's class to target's class, with target's data. The reason why we change the class of a lazy tensor in-place is that this can easily handle shared modules/parameters, which is common in huggingface models. If we create a new tensor and update the module by ``setattr(module, name, param)``, the shared parameters will not be updated. And we have to track all shared parameters and update them manually. Args: tensor (LazyTensor): the LazyTensor to be converted target (torch.Tensor): target tensor Returns: torch.Tensor: the converted tensor """ cls_to_become = Parameter if isinstance(tensor, Parameter) else torch.Tensor tensor.__class__ = cls_to_become if cls_to_become is Parameter: # to fit UninitializedParameter delattr(tensor, "_is_param") tensor.data = target tensor.requires_grad = target.requires_grad # subclass of torch.Tensor does not have tolist() method # overwrite this method after materialization or distribution tensor.tolist = MethodType(_data_tolist, tensor) return tensor class LazyTensor(torch.Tensor): """A naive implementation of LazyTensor (https://arxiv.org/pdf/2102.13267.pdf). Usage: 1. Use ``LazyTensor`` instead of ``torch.Tensor``. >>> x = LazyTensor(torch.zeros, 2, 3) >>> x += 1 >>> y = x * x >>> y = y.cuda().half() >>> y[0, 0] = 0 >>> y = y.materialize() # materialize the tensor >>> print(y) tensor([[0., 1., 1.], [1., 1., 1.]], device='cuda:0', dtype=torch.float16) Warnings: 1. Cases that ``LazyTensor`` can't deal with. >>> x = LazyTensor(torch.ones, 2, 3) >>> x[0, 0] = -x[0, 0] # this will cause infinite recursion >>> y = x.clone() >>> x.add_(1) # modifying origin tensor after cloning leads to wrong materialization >>> z = x.tolist() >>> x.zeros_() # modifying origin tensor after cloning tolist is not allowed >>> nn.utils.weight_norm(self.conv, name="weight", dim=2) # applying weight norm on a lazy tensor is not allowed 2. Cases that ``LazyTensor`` becomes eager (early materialization). >>> b = a[:, 2:] # get a slice of a lazy tensor triggers early materialization >>> chunks = a.split(3) # this also triggers early materialization >>> x.data = torch.rand(2, 3) # directly setting data of a lazy tensor triggers early materialization """ _repr = True _meta_data: Optional[torch.Tensor] = None # shape, dtype, device _pre_op_fn: Callable[["LazyTensor"], None] = lambda *args: None default_device: Optional[torch.device] = None _device: torch.device # fake device of mate tensor @staticmethod def __new__(cls, func, *args, meta_data=None, concrete_data=None, **kwargs): # tips for torch 2.0: # torch 2.0 disables torch dispatch for subclass of tensor # MetaTensor is cannot be used # Now lazy tensor contains device injection and meta tensor if concrete_data is not None: # some ops don't support meta backend and should have concrete data elem = concrete_data else: if meta_data is None: with ConstructorManager.disable(): # to disable create lazy tensor in inner ops, this is a hack for torch 2.0 meta_data = func(*args, **{**kwargs, "device": "meta"}) elem = meta_data # As a meta tensor cannot be modified __class__ to torch.Tensor, we should use an empty real tensor here r = torch.Tensor._make_subclass(cls, _EMPTY_DATA, require_grad=elem.requires_grad) r._meta_data = meta_data return r def __init__(self, func, *args, meta_data=None, concrete_data=None, **kwargs): self._device = torch.device(kwargs.get("device", None) or "cpu") if func.__name__ in _NORMAL_FACTORY: kwargs = {**kwargs, "device": LazyTensor.default_device} self._factory_method = (func, args, kwargs) # (func, args, kwargs) self._op_buffer = [] # (func, args, kwargs, replace) self._materialized_data: Optional[torch.Tensor] = concrete_data # materialized data @property def device(self) -> torch.device: return self._materialized_data.device if self._materialized_data is not None else self._device def __repr__(self): return f"LazyTensor(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})" def materialize(self) -> torch.Tensor: """Materialize the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace). Returns: torch.Tensor: The materialized tensor (self). """ target = self._materialize_data() self.clean() return _convert_cls(self, target) def clean(self) -> None: """Clean all stored operations, meta data and materialized data, which prevents memory leaking. This should be called after all tensors are materialized.""" delattr(self, "_factory_method") delattr(self, "_op_buffer") delattr(self, "_materialized_data") delattr(self, "_meta_data") @staticmethod def _replace_with_materialized(x): if isinstance(x, LazyTensor): return x._materialize_data() return x def _materialize_data(self) -> torch.Tensor: # self._materialized_data should be generated after the first call of this function if self._materialized_data is None: # apply factory method func, args, kwargs = self._factory_method # apply cached sequence self._pre_op_fn() init_val = func( *tree_map(self._replace_with_materialized, args), **tree_map(self._replace_with_materialized, kwargs) ) self._materialized_data = self._rerun_ops(init_val) return self._materialized_data def _rerun_ops(self, target=None) -> torch.Tensor: """Do lazy execution by rerunning all (stored) related operations. Args: target (torc.Tensor, optional): Intial value of the target tensor (self). Defaults to None. """ def replace(x): if x is self: return target elif isinstance(x, LazyTensor): return x._materialize_data() return x packed = None for func, args, kwargs in self._op_buffer: if func == torch.Tensor.requires_grad_: packed = func, args, kwargs # requires grad should be set at last else: self._pre_op_fn() o = func(*tree_map(replace, args), **tree_map(replace, kwargs)) target = o if isinstance(o, torch.Tensor) else target # if func returns non-Tensor, discard the value # super-dainiu: set requires_grad after all inplace-ops are done if packed is not None: func, args, kwargs = packed func(*tree_map(replace, args), **tree_map(replace, kwargs)) return target # cache everything with __torch_function__ @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} if func.__name__ in _EARLY_MATERIALIZED_OPS: # These OPs cannot be lazy and related tensors should be early materialized tree_map(cls._replace_with_materialized, args) tree_map(cls._replace_with_materialized, kwargs) is_inplace: bool = ( func.__name__.endswith("_") and not (func.__name__.endswith("__")) or func.__name__ in ("__setitem__", "__set__") ) is_change_meta_op: bool = func.__name__ in _CHANGE_META_OPS if isinstance(func, torch._C.ScriptMethod): # FIXME(ver217): torch script functions are not verified target = None def unwrap(x): if isinstance(x, LazyTensor): return x._meta_data return x target: LazyTensor = args[0].clone() target._op_buffer.append((func, args, kwargs)) target._meta_data = getattr(target._meta_data, func.name)( *tree_map(unwrap, args[1:]), **tree_map(unwrap, kwargs) ) return target else: meta_to_lazy = {} def unwrap(x): if isinstance(x, LazyTensor): if x._materialized_data is not None: # for early materialized tensor, use its materialized data directly return x._materialized_data if is_change_meta_op else x._materialized_data.data t = x if is_inplace else x.clone() if func.__name__ not in _NO_RERUN_OPS: t._op_buffer.append((func, args, kwargs)) meta = x._meta_data if is_change_meta_op else x._meta_data.data meta_to_lazy[meta] = t return meta elif ( version.parse(torch.__version__) >= version.parse("2.0.0") and func.__name__ in _EXPAND_SCALAR_OPS and not isinstance(x, torch.Tensor) ): return _old_tensor_factory(x, device="meta") return x def wrap(y, i=None): if isinstance(y, torch.Tensor): if y.is_meta: if y in meta_to_lazy: # inplace op, just return origin lazy tensor return meta_to_lazy[y] else: # out of place op, create new lazy tensor fn = lambda *a, **kw: func(*a, **kw) if i is None else func(*a, **kw)[i] fn.__name__ = func.__name__ lazy_y = LazyTensor(fn, *args, meta_data=y, **kwargs) return lazy_y else: # for early materialized tensor return LazyTensor(lambda: None, concrete_data=y) return y cls._pre_op_fn() with ConstructorManager.disable(): # to disable create lazy tensor in inner ops, this is a hack for torch 2.0 o = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)) if isinstance(o, (tuple, list)): return type(o)(wrap(y, i=i) for i, y in enumerate(o)) return wrap(o) def to(self, *args, **kwargs) -> torch.Tensor: if self._materialized_data is not None: return LazyTensor(lambda: None, concrete_data=self._materialized_data.to(*args, **kwargs)) device = None def replace(x): nonlocal device if isinstance(x, (str, int, torch.device)) and not isinstance(x, bool): device = x return torch.device("meta") return x meta_data = self._meta_data.to(*tree_map(replace, args), **tree_map(replace, kwargs)) if meta_data is self._meta_data and device == self.device: return self def factory_fn(t: torch.Tensor, **kw): return t.to(*args, **kwargs) return LazyTensor(factory_fn, self, meta_data=meta_data, device=device) def cpu(self, memory_format: torch.memory_format = torch.preserve_format): return self.to(device=torch.device("cpu"), memory_format=memory_format) def cuda(self, device=None, non_blocking=False, memory_format: torch.memory_format = torch.preserve_format): device = torch.device(device or "cuda") return self.to(device=device, non_blocking=non_blocking, memory_format=memory_format) def clone(self) -> "LazyTensor": def factory_fn(t: torch.Tensor, **kw): # if self is materialized, return self return t.clone() target = LazyTensor(factory_fn, self, meta_data=self._meta_data) return target def detach(self) -> Tensor: return self def __deepcopy__(self, memo): if not self.is_leaf: raise RuntimeError( "Only Tensors created explicitly by the user " "(graph leaves) support the deepcopy protocol at the moment" ) if id(self) in memo: return memo[id(self)] def factory_fn(t: torch.Tensor, **kw): # if self is materialized, return self return _copy_tensor(t, t.requires_grad) if self._materialized_data is not None: # self is early materialized copied = _copy_tensor(self._materialized_data, self.requires_grad) target = LazyTensor(lambda: None, concrete_data=copied) else: target = LazyTensor(factory_fn, self, meta_data=self._meta_data) if isinstance(self, Parameter): # hack isinstance check of parameter target._is_param = True memo[id(self)] = target return target @property def data(self): return self @data.setter def data(self, other: "LazyTensor"): """This is sightly different from oringinal `data` setter. E.g.: >>> a = torch.randn(3, 3) # a is a Tensor >>> b = torch.rand(2, 2) >>> a.data = b >>> b.add_(1) # this will affect a >>> x = torch.randn(3, 3) # x is a LazyTensor >>> y = torch.rand(2, 2) # y is a LazyTensor >>> x.data = y >>> y.add_(1) # this will not affect x """ if other is self: return def replace(x): if x is other: return self return x for func, args, kwargs in [other._factory_method, *other._op_buffer]: self._op_buffer.append((func, tree_map(replace, args), tree_map(replace, kwargs))) def tolist(self) -> list: # Though self.__class__ is modified to torch.Tensor, in C++ side, it is still a subclass of torch.Tensor # And subclass of torch.Tensor does not have tolist() method t = self._materialize_data() return t.tolist() def __hash__(self): return id(self) def __rpow__(self, other): dtype = torch.result_type(self, other) return torch.tensor(other, dtype=dtype, device=self.device) ** self class LazyInitContext: """Context manager for lazy initialization. Enables initializing the model without allocating real memory. Args: tensor_cls (Union[_MyTensor, LazyTensor], optional): This is only for test. Defaults to LazyTensor. default_device (Optional[Union[torch.device, str, int]], optional): Defalt device for initialization. If it's cuda, initilization will be accelerated, but cuda memory will be allocated. By default, it's cpu. Defaults to None. """ _replaced: bool = False def __init__( self, tensor_cls: Union[_MyTensor, LazyTensor] = LazyTensor, default_device: Optional[Union[torch.device, str, int]] = None, ): assert tensor_cls is LazyTensor or tensor_cls is _MyTensor self.tensor_cls = tensor_cls self.old_default_device = LazyTensor.default_device self.default_device = default_device def __enter__(self): if LazyInitContext._replaced: raise RuntimeError(f"LazyInitContext is not reentrant") LazyInitContext._replaced = True self.old_default_device = self.tensor_cls.default_device self.tensor_cls.default_device = self.default_device def wrap_factory_method(target): # factory functions (eg. torch.empty()) def wrapper(*args, **kwargs): return self.tensor_cls(target, *args, **kwargs) return wrapper, target def wrap_factory_like_method(orig_target, target): # factory_like functions (eg. torch.empty_like()) def wrapper(*args, **kwargs): orig_t = args[0] return self.tensor_cls( orig_target, *orig_t.shape, *args[1:], device=orig_t.device, dtype=orig_t.dtype, **kwargs ) return wrapper, target def wrap_legacy_constructor(target, dtype): # legacy constructor (e.g. torch.LongTensor()) def wrapper(*args, **kwargs): if len(args) == 1 and isinstance(args[0], torch.Tensor): # (Tensor other) return args[0] elif len(args) == 1: # (object data, *, torch.device device) kwargs = {**kwargs, "dtype": dtype} replaced, orig = self.overrides["tensor"] return replaced(*args, **kwargs) elif _is_int_tuple(args): # (tuple of ints size, *, torch.device device) kwargs = {**kwargs, "dtype": dtype} replaced, orig = self.overrides["empty"] return replaced(*args, **kwargs) else: raise TypeError( f"new() received an invalid combination of arguments - got {tuple(type(x) for x in args)}, but expected one of:\n * (Tensor other)\n * (tuple of ints size, *, torch.device device)\n * (object data, *, torch.device device)" ) return wrapper, target def wrap_no_meta_factory(target): # factory functions which don't support meta tensor backend def wrapper(*args, **kwargs): tensor = target(*args, **kwargs) return self.tensor_cls(lambda: None, concrete_data=tensor) return wrapper, target overrides = { target: wrap_factory_method(getattr(torch, target)) for target in _NORMAL_FACTORY if callable(getattr(torch, target, None)) } overrides.update( { target + "_like": wrap_factory_like_method(getattr(torch, target), getattr(torch, target + "_like")) for target in _NORMAL_FACTORY if callable(getattr(torch, target + "_like", None)) } ) overrides.update( { target: wrap_legacy_constructor(getattr(torch, target), dtype) for target, dtype in _LEGACY_TENSOR_CONSTRUCTOR.items() if callable(getattr(torch, target, None)) } ) overrides.update( { target: wrap_no_meta_factory(getattr(torch, target)) for target in _NO_META_FACTORY if callable(getattr(torch, target, None)) } ) ConstructorManager.apply(overrides) PretrainedManager.inject() def __exit__(self, exc_type, exc_val, exc_tb): self.tensor_cls.default_device = self.old_default_device LazyInitContext._replaced = False ConstructorManager.clear() PretrainedManager.recover() @staticmethod def materialize(module: nn.Module, verbose: bool = False) -> nn.Module: """Initialize all ``Parameter`` from ``LazyTensor``. This function will modify the module in-place. Args: module (nn.Module): Target ``nn.Module`` verbose (bool): Whether to print lazy initialization rate. Defaults to False. """ def apply_fn(name: str, p: LazyTensor): p.materialize() return _apply_to_lazy_module(module, apply_fn, verbose) def _apply_to_lazy_module( module: nn.Module, apply_fn: Callable[[str, torch.Tensor], None], verbose: bool = False ) -> nn.Module: if verbose: # verbose info param_cnt = 0 param_lazy_cnt = 0 buf_cnt = 0 buf_lazy_cnt = 0 total_numel = 0 non_lazy_numel = 0 for name, p in module.named_parameters(): if verbose: param_cnt += 1 total_numel += p.numel() if getattr(p, "_materialized_data", False) is None: # if no _materialized_data attr, the tensor is not lazy param_lazy_cnt += 1 else: non_lazy_numel += p.numel() if isinstance(p, LazyTensor): apply_fn(name, p) for name, buf in module.named_buffers(): if verbose: buf_cnt += 1 total_numel += buf.numel() if getattr(buf, "_materialized_data", False) is None: # if no _materialized_data attr, the tensor is not lazy buf_lazy_cnt += 1 else: non_lazy_numel += buf.numel() if isinstance(buf, LazyTensor): apply_fn(name, buf) if verbose: non_lazy_numel_ratio = non_lazy_numel / total_numel * 100 if non_lazy_numel != 0 else 0 logger = get_dist_logger() logger.info(f"Param lazy rate: {param_lazy_cnt}/{param_cnt}", ranks=[0]) logger.info(f"Buffer lazy rate: {buf_lazy_cnt}/{buf_cnt}", ranks=[0]) logger.info( f"Non lazy numel: {non_lazy_numel} ({non_lazy_numel/1024**2:.3f} M), ratio: {non_lazy_numel_ratio}%", ranks=[0], ) return module def _is_int_tuple(args) -> bool: if not isinstance(args, tuple): return False for x in args: if not isinstance(x, int): return False return True def _copy_tensor(tensor: Tensor, requires_grad: bool) -> Tensor: copied = tensor.data.clone() copied.requires_grad = requires_grad return copied