mirror of https://github.com/hpcaitech/ColossalAI
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
612 lines
23 KiB
612 lines
23 KiB
2 years ago
|
from types import MethodType
|
||
|
from typing import Callable, Optional, Union
|
||
2 years ago
|
|
||
|
import torch
|
||
2 years ago
|
import torch.distributed as dist
|
||
2 years ago
|
import torch.nn as nn
|
||
2 years ago
|
from torch import Tensor
|
||
2 years ago
|
from torch.utils._pytree import tree_map
|
||
|
|
||
2 years ago
|
from colossalai._analyzer._subclasses import MetaTensor
|
||
2 years ago
|
from colossalai.tensor.d_tensor.d_tensor import DTensor
|
||
|
from colossalai.tensor.d_tensor.layout import Layout
|
||
2 years ago
|
|
||
|
# reference: https://pytorch.org/cppdocs/notes/tensor_creation.html
# Factory functions that support the meta backend; these are patched by
# ``LazyInitContext`` to return lazy tensors instead of real ones.
_NORMAL_FACTORY = [
    "arange",
    "full",
    "empty",
    "linspace",
    "logspace",
    "ones",
    "rand",
    "randn",
    "randint",
    "randperm",
    "zeros",
    "tensor",
]

# factory function that does not support meta tensor backend
_NO_META_FACTORY = [
    "eye",
]

# Ops that cannot be deferred: their lazy operands are materialized eagerly
# (see ``LazyTensor.__torch_function__``).
_EARLY_MATERIALIZED_OPS = ['__getitem__', 'split']

# If your intent is to change the metadata of a Tensor (such as sizes / strides / storage / storage_offset)
# without autograd tracking the change, remove the .data / .detach() call and wrap the change in a `with torch.no_grad():` block.
# These ops cannot be unwrapped using .data
_CHANGE_META_OPS = ['_cudnn_rnn_flatten_weight', 'requires_grad_', '__get__', '__set__']

# Legacy dtype-specific constructors (e.g. ``torch.LongTensor``) mapped to the
# dtype they imply; used by ``LazyInitContext`` to re-route them through the
# modern ``tensor``/``empty`` factories.
_LEGACY_TENSOR_CONSTRUCTOR = {
    'FloatTensor': torch.float,
    'DoubleTensor': torch.double,
    'HalfTensor': torch.half,
    'BFloat16Tensor': torch.bfloat16,
    'ByteTensor': torch.uint8,
    'CharTensor': torch.int8,
    'ShortTensor': torch.short,
    'IntTensor': torch.int,
    'LongTensor': torch.long,
    'BoolTensor': torch.bool,
}

# Shared zero-element tensor used as the placeholder storage of every
# un-materialized ``LazyTensor`` (see ``LazyTensor.__new__``).
_EMPTY_DATA = torch.empty(0)
|
|
||
|
class _MyTensor(Tensor):
|
||
|
"""This class is only for correctness verification.
|
||
|
"""
|
||
|
_pre_op_fn: Callable[['LazyTensor'], None] = lambda *args: None
|
||
|
|
||
2 years ago
|
def __new__(cls, func, *args, concrete_data=None, **kwargs) -> '_MyTensor':
|
||
2 years ago
|
cls._pre_op_fn()
|
||
2 years ago
|
if concrete_data is not None:
|
||
|
# uniform api as LazyTensor
|
||
|
data = concrete_data
|
||
|
else:
|
||
|
data = func(*args, **kwargs)
|
||
2 years ago
|
return Tensor._make_subclass(cls, data, require_grad=data.requires_grad)
|
||
|
|
||
|
@classmethod
|
||
|
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
||
|
cls._pre_op_fn()
|
||
|
return super().__torch_function__(func, types, args, kwargs)
|
||
2 years ago
|
|
||
|
|
||
2 years ago
|
def _data_tolist(tensor: torch.Tensor) -> list:
|
||
|
"""tolist() method is not allowed for a subclass of tensor. Tensor.data returns a Tensor.
|
||
|
"""
|
||
|
return tensor.data.tolist()
|
||
|
|
||
|
|
||
2 years ago
|
def _convert_cls(tensor: 'LazyTensor', target: torch.Tensor) -> torch.Tensor:
    """Convert a lazy tensor's class to target's class, with target's data.

    The conversion is done in place (by reassigning ``__class__``) because that
    transparently handles shared modules/parameters, which are common in
    huggingface models: creating a new tensor and re-assigning it via
    ``setattr(module, name, param)`` would leave other holders of the shared
    parameter un-updated, forcing manual tracking of every tie.

    Args:
        tensor (LazyTensor): the LazyTensor to be converted
        target (torch.Tensor): target tensor

    Returns:
        torch.Tensor: the converted tensor
    """
    new_cls = nn.Parameter if isinstance(tensor, nn.Parameter) else torch.Tensor
    tensor.__class__ = new_cls
    tensor.data = target
    tensor.requires_grad = target.requires_grad
    # Even after the class swap, the C++ side still treats this object as a
    # tensor subclass, and subclasses lack tolist(); bind a .data-based
    # replacement after materialization or distribution.
    tensor.tolist = MethodType(_data_tolist, tensor)
    return tensor
2 years ago
|
class LazyTensor(torch.Tensor):
    """A naive implementation of LazyTensor (https://arxiv.org/pdf/2102.13267.pdf).

    Usage:
        1. Use ``LazyTensor`` instead of ``torch.Tensor``.
        >>> x = LazyTensor(torch.zeros, 2, 3)
        >>> x += 1
        >>> y = x * x
        >>> y = y.cuda().half()
        >>> y[0, 0] = 0
        >>> y = y.materialize()     # materialize the tensor
        >>> print(y)
        tensor([[0., 1., 1.],
                [1., 1., 1.]], device='cuda:0', dtype=torch.float16)

    Warnings:
        1. Cases that ``LazyTensor`` can't deal with.
        >>> x = LazyTensor(torch.ones, 2, 3)
        >>> x[0, 0] = -x[0, 0] # this will cause infinite recursion
        >>> y = x.clone()
        >>> x.add_(1) # modifying origin tensor after cloning leads to wrong materialization
        >>> z = x.tolist()
        >>> x.zeros_() # modifying origin tensor after cloning tolist is not allowed
        >>> nn.utils.weight_norm(self.conv, name="weight", dim=2) # applying weight norm on a lazy tensor is not allowed

        2. Cases that ``LazyTensor`` becomes eager (early materialization).
        >>> b = a[:, 2:]  # get a slice of a lazy tensor triggers early materialization
        >>> chunks = a.split(3)  # this also triggers early materialization
        >>> x.data = torch.rand(2, 3) # directly setting data of a lazy tensor triggers early materialization

    """

    _repr = True
    # Meta tensor tracking the shape / dtype / device of the pending result.
    _meta_data: Optional[MetaTensor] = None    # shape, dtype, device
    # Hook called before the factory method and before each replayed op; no-op by default.
    _pre_op_fn: Callable[['LazyTensor'], None] = lambda *args: None

    @staticmethod
    def __new__(cls, func, *args, meta_data=None, concrete_data=None, **kwargs):
        """Allocate the subclass instance; real storage is a shared empty tensor."""
        if concrete_data is not None:
            # some ops don't support meta backend and should have concrete data
            elem = concrete_data
        else:
            if meta_data is None:
                device = kwargs.get('device', 'cpu')
                # run the factory on the meta device so no real memory is allocated
                elem = func(*args, **{**kwargs, 'device': 'meta'})
                meta_data = MetaTensor(elem, device=device)
            elem = meta_data._tensor
        # As a meta tensor cannot be modified __class__ to torch.Tensor, we should use an empty real tensor here
        r = torch.Tensor._make_subclass(cls, _EMPTY_DATA, require_grad=elem.requires_grad)
        r._meta_data = meta_data
        return r

    def __init__(self, func, *args, meta_data=None, concrete_data=None, **kwargs):
        # Record how to (re)create this tensor; replayed by _materialize_data().
        self._factory_method = (func, args, kwargs)    # (func, args, kwargs)
        self._op_buffer = []    # (func, args, kwargs, replace)
        self._materialized_data: Optional[torch.Tensor] = concrete_data    # materialized data

    def materialize(self) -> torch.Tensor:
        """Materialize the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace).

        Returns:
            torch.Tensor: The materialized tensor (self).
        """
        target = self._materialize_data()
        self.clean()
        return _convert_cls(self, target)

    def distribute(self, layout: Layout) -> torch.Tensor:
        """Distribute the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace), according to the layout.

        Args:
            layout (Layout): Distribution layout.

        Returns:
            torch.Tensor: The distributed tensor (self).
        """
        target = self._materialize_data()
        self.clean()
        local_tensor = DTensor(target, layout).local_tensor
        return _convert_cls(self, local_tensor)

    def clean(self) -> None:
        """Clean all stored operations, meta data and materialized data, which prevents memory leaking. This should be called after all tensors are materialized.
        """
        self._factory_method = None
        self._op_buffer = None
        self._materialized_data = None
        self._meta_data = None

    @staticmethod
    def _replace_with_materialized(x):
        # Materialize lazy tensors in place; pass any other value through untouched.
        if isinstance(x, LazyTensor):
            return x._materialize_data()
        return x

    def _materialize_data(self) -> torch.Tensor:
        """Run the recorded factory method and replay buffered ops; result is cached."""
        # self._materialized_data should be generated after the first call of this function
        if self._materialized_data is None:
            # apply factory method
            func, args, kwargs = self._factory_method

            # apply cached sequence
            self._pre_op_fn()

            try:
                init_val = func(*tree_map(self._replace_with_materialized, args),
                                **tree_map(self._replace_with_materialized, kwargs))
            except TypeError as e:
                # surface which factory failed before re-raising
                print(f'init fn: {func.__name__}')
                raise e

            self._materialized_data = self._rerun_ops(init_val)
        return self._materialized_data

    def _rerun_ops(self, target=None) -> torch.Tensor:
        """Do lazy execution by rerunning all (stored) related operations.

        Args:
            target (torch.Tensor, optional): Initial value of the target tensor (self). Defaults to None.
        """

        def replace(x):
            # Substitute self with the accumulating target; materialize other lazy operands.
            if x is self:
                return target
            elif isinstance(x, LazyTensor):
                return x._materialize_data()
            return x

        packed = None

        for (func, args, kwargs) in self._op_buffer:
            if func == torch.Tensor.requires_grad_:
                packed = func, args, kwargs    # requires grad should be set at last
            else:
                self._pre_op_fn()
                o = func(*tree_map(replace, args), **tree_map(replace, kwargs))
                target = o if isinstance(o, torch.Tensor) else target    # if func returns non-Tensor, discard the value

        # super-dainiu: set requires_grad after all inplace-ops are done
        if packed is not None:
            func, args, kwargs = packed
            func(*tree_map(replace, args), **tree_map(replace, kwargs))

        return target

    # cache everything with __torch_function__

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func.__name__ in _EARLY_MATERIALIZED_OPS:
            # These OPs cannot be lazy and related tensors should be early materialized
            tree_map(cls._replace_with_materialized, args)
            tree_map(cls._replace_with_materialized, kwargs)
        # trailing single underscore marks an in-place op (dunders excluded)
        is_inplace: bool = (func.__name__.endswith('_') and not (func.__name__.endswith('__'))
                            or func.__name__ in ('__setitem__', '__set__'))

        is_change_meta_op: bool = func.__name__ in _CHANGE_META_OPS

        if isinstance(func, torch._C.ScriptMethod):
            # FIXME(ver217): torch script functions are not verified

            target = None

            def unwrap(x):
                if isinstance(x, LazyTensor):
                    return x._meta_data
                return x

            target: LazyTensor = args[0].clone()
            target._op_buffer.append((func, args, kwargs))
            target._meta_data = getattr(target._meta_data, func.name)(*tree_map(unwrap, args[1:]),
                                                                      **tree_map(unwrap, kwargs))
            return target
        else:

            # maps each unwrapped meta tensor back to the lazy tensor it came from
            meta_to_lazy = {}

            def unwrap(x):
                if isinstance(x, LazyTensor):
                    if x._materialized_data is not None:
                        # for early materialized tensor, use its materialized data directly
                        return x._materialized_data if is_change_meta_op else x._materialized_data.data
                    # in-place ops record onto self; out-of-place ops onto a clone
                    t = x if is_inplace else x.clone()
                    t._op_buffer.append((func, args, kwargs))
                    meta = x._meta_data if is_change_meta_op else x._meta_data.data
                    meta_to_lazy[meta] = t
                    return meta
                return x

            def wrap(y, i=None):
                if isinstance(y, MetaTensor):
                    if y in meta_to_lazy:
                        # inplace op, just return origin lazy tensor
                        return meta_to_lazy[y]
                    else:
                        # out of place op, create new lazy tensor
                        # ``i`` selects the i-th output when func returns a tuple/list
                        fn = lambda *a, **kw: func(*a, **kw) if i is None else func(*a, **kw)[i]
                        lazy_y = LazyTensor(fn, *args, meta_data=y, **kwargs)
                        return lazy_y
                elif type(y) is Tensor:
                    # for early materialized tensor
                    return LazyTensor(lambda: None, concrete_data=y)
                return y

            cls._pre_op_fn()
            # run the op on meta data only; real execution is deferred
            o = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
            if isinstance(o, (tuple, list)):
                return type(o)(wrap(y, i=i) for i, y in enumerate(o))
            return wrap(o)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # intentionally inert: all interception happens in __torch_function__
        pass    # skip

    def clone(self) -> "LazyTensor":
        """Return a new lazy tensor whose factory materializes self and clones it."""

        def factory_fn():
            # if self is materialized, return self
            new_tensor = self.materialize() if type(self) is LazyTensor else self
            return new_tensor.clone()

        target = LazyTensor(factory_fn, meta_data=self._meta_data)

        return target

    def detach(self) -> Tensor:
        # NOTE(review): returns self unchanged — presumably detaching is deferred
        # until materialization; confirm against autograd expectations.
        return self

    def __deepcopy__(self, memo):
        if not self.is_leaf:
            raise RuntimeError("Only Tensors created explicitly by the user "
                               "(graph leaves) support the deepcopy protocol at the moment")
        if id(self) in memo:
            return memo[id(self)]

        def factory_fn():
            # if self is materialized, return self
            new_tensor = self.materialize() if type(self) is LazyTensor else self
            copied = new_tensor.detach().clone()
            if new_tensor.requires_grad:
                copied.requires_grad_()
            return copied

        if self._materialized_data is not None:
            # self is early materialized
            copied = self._materialized_data.detach().clone()
            if self.requires_grad:
                copied.requires_grad_()
            target = LazyTensor(lambda: None, concrete_data=copied)
        else:
            target = LazyTensor(factory_fn, meta_data=self._meta_data)

        memo[id(self)] = target
        return target

    @property
    def data(self):
        # a lazy tensor's data is itself; assignment is handled by the setter below
        return self

    @data.setter
    def data(self, other: 'LazyTensor'):
        """This is slightly different from the original `data` setter.

        E.g.:
            >>> a = torch.randn(3, 3) # a is a Tensor
            >>> b = torch.rand(2, 2)
            >>> a.data = b
            >>> b.add_(1)   # this will affect a
            >>> x = torch.randn(3, 3) # x is a LazyTensor
            >>> y = torch.rand(2, 2) # y is a LazyTensor
            >>> x.data = y
            >>> y.add_(1)   # this will not affect x

        """
        if other is self:
            return

        # replay other's factory and recorded ops onto self (with self substituted for other)
        self._op_buffer.append(other._factory_method)

        def replace(x):
            if x is other:
                return self
            return x

        for func, args, kwargs in other._op_buffer:
            self._op_buffer.append((func, tree_map(replace, args), tree_map(replace, kwargs)))

    def tolist(self) -> list:
        # Though self.__class__ is modified to torch.Tensor, in C++ side, it is still a subclass of torch.Tensor
        # And subclass of torch.Tensor does not have tolist() method
        t = self._materialize_data()
        return t.tolist()

    def __hash__(self):
        # identity hash: lazy tensors are used as dict keys before materialization
        return id(self)
|
||
|
class LazyInitContext:
    """Context manager for lazy initialization. Enables initializing the model without allocating real memory.

    Usage:
        1. The model is initialized, but no real memory is allocated.
        >>> ctx = LazyInitContext()
        >>> with ctx:
        >>>     model = MyModel().cuda()

        2. The model is initialized with ``MetaTensor`` as weights, but still no real memory is allocated.
        >>> with ctx.traceable(model):
        >>>     gm = symbolic_trace(model, meta_args=meta_args)
        >>> # Solve the execution strategy and apply the strategy to the model
        >>> strategy = StrategyAndSpec()

        3. The model is initialized with ``torch.Tensor`` as weights, and real memory is allocated. (single device)
        >>> model = ctx.materialize(model)

        3. The model is initialized with sharded ``torch.Tensor`` as weights, and real memory is allocated. (distributed scenario)
        >>> model = apply_strategy_to_all_params(model, strategy)
        >>> model = ctx.distribute(model)

    Warnings:
        This API is still experimental and further modifications can be made to it.
        For example:
        1. Quantization strategies can be applied before allocating real memory.
        2. Lazy initialization seems slower than normal initialization.
    """
    # Guards against nested contexts: only one set of torch patches may be active.
    _replaced: bool = False

    def __init__(self, tensor_cls: Union[_MyTensor, LazyTensor] = LazyTensor):
        # name -> (wrapper, original) for every patched torch attribute
        self.overrides = {}
        # tensor class to construct inside the context (LazyTensor, or _MyTensor for verification)
        self.tensor_cls = tensor_cls

    def __enter__(self):
        if LazyInitContext._replaced:
            raise RuntimeError(f'LazyInitContext is not reentrant')
        LazyInitContext._replaced = True

        def wrap_factory_method(target):
            # factory functions (eg. torch.empty())
            def wrapper(*args, **kwargs):
                return self.tensor_cls(target, *args, **kwargs)

            return wrapper, target

        def wrap_factory_like_method(orig_target, target):
            # factory_like functions (eg. torch.empty_like())
            def wrapper(*args, **kwargs):
                # inherit device/dtype from the reference tensor, as *_like does
                orig_t = args[0]
                return self.tensor_cls(orig_target, *args[1:], device=orig_t.device, dtype=orig_t.dtype, **kwargs)

            return wrapper, target

        def wrap_legacy_constructor(target, dtype):
            # legacy constructor (e.g. torch.LongTensor())
            def wrapper(*args, **kwargs):
                if len(args) == 1 and isinstance(args[0], torch.Tensor):
                    # (Tensor other)
                    return args[0]
                elif len(args) == 1:
                    # (object data, *, torch.device device)
                    kwargs = {**kwargs, 'dtype': dtype}
                    # re-route through the already-patched torch.tensor wrapper
                    replaced, orig = self.overrides['tensor']
                    return replaced(*args, **kwargs)
                elif _is_int_tuple(args):
                    # (tuple of ints size, *, torch.device device)
                    kwargs = {**kwargs, 'dtype': dtype}
                    # re-route through the already-patched torch.empty wrapper
                    replaced, orig = self.overrides['empty']
                    return replaced(*args, **kwargs)
                else:
                    raise TypeError(
                        f'new() received an invalid combination of arguments - got {tuple(type(x) for x in args)}, but expected one of:\n * (Tensor other)\n * (tuple of ints size, *, torch.device device)\n * (object data, *, torch.device device)'
                    )

            return wrapper, target

        def wrap_no_meta_factory(target):
            # factory functions which don't support meta tensor backend
            def wrapper(*args, **kwargs):
                # run eagerly and wrap the concrete result
                tensor = target(*args, **kwargs)
                return self.tensor_cls(lambda: None, concrete_data=tensor)

            return wrapper, target

        self.overrides = {
            target: wrap_factory_method(getattr(torch, target))
            for target in _NORMAL_FACTORY
            if callable(getattr(torch, target, None))
        }

        self.overrides.update({
            target + '_like': wrap_factory_like_method(getattr(torch, target), getattr(torch, target + '_like'))
            for target in _NORMAL_FACTORY
            if callable(getattr(torch, target + '_like', None))
        })

        self.overrides.update({
            target: wrap_legacy_constructor(getattr(torch, target), dtype)
            for target, dtype in _LEGACY_TENSOR_CONSTRUCTOR.items()
            if callable(getattr(torch, target, None))
        })

        self.overrides.update({
            target: wrap_no_meta_factory(getattr(torch, target))
            for target in _NO_META_FACTORY
            if callable(getattr(torch, target, None))
        })

        # install all wrappers onto the torch module
        for name, (wrapper, orig) in self.overrides.items():
            setattr(torch, name, wrapper)

    def __exit__(self, exc_type, exc_val, exc_tb):
        LazyInitContext._replaced = False
        # restore the original torch attributes
        for name, (wrapper, orig) in self.overrides.items():
            setattr(torch, name, orig)

    @staticmethod
    def materialize(module: nn.Module, verbose: bool = False) -> nn.Module:
        """Initialize all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place.

        Args:
            module (nn.Module): Target ``nn.Module``
            verbose (bool): Whether to print lazy initialization rate. Defaults to False.
        """

        def apply_fn(name: str, p: LazyTensor):
            p.materialize()

        return _apply_to_lazy_module(module, apply_fn, verbose)

    @staticmethod
    def distribute(module: nn.Module, layout_dict: dict, verbose: bool = False) -> nn.Module:
        """Distribute all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place.

        Args:
            module (nn.Module): Target ``nn.Module``
            layout_dict (dict): Dict of layout for each parameter/buffer. The key is the parameter/buffer name, and the value is the layout.
            verbose (bool, optional): Whether to print lazy initialization rate. Defaults to False.
        """

        def apply_fn(name: str, p: LazyTensor):
            p.distribute(layout_dict[name])

        return _apply_to_lazy_module(module, apply_fn, verbose)
||
2 years ago
|
def _apply_to_lazy_module(module: nn.Module,
                          apply_fn: Callable[[str, torch.Tensor], None],
                          verbose: bool = False) -> nn.Module:
    """Visit every parameter and buffer of ``module``, calling ``apply_fn`` on lazy ones.

    Args:
        module (nn.Module): Target module; modified in place by ``apply_fn``.
        apply_fn (Callable): Invoked as ``apply_fn(name, tensor)`` for each ``LazyTensor``.
        verbose (bool): Whether to print lazy initialization rate. Defaults to False.

    Returns:
        nn.Module: The same module.
    """
    if verbose:
        # counters for the lazy-initialization report
        param_cnt = 0
        param_lazy_cnt = 0
        buf_cnt = 0
        buf_lazy_cnt = 0
        total_numel = 0
        non_lazy_numel = 0

    for name, param in module.named_parameters():
        if verbose:
            param_cnt += 1
            total_numel += param.numel()
            # a tensor lacking the `_materialized_data` attr is not lazy;
            # a still-pending lazy tensor has it set to None
            if getattr(param, '_materialized_data', False) is None:
                param_lazy_cnt += 1
            else:
                non_lazy_numel += param.numel()
        if isinstance(param, LazyTensor):
            apply_fn(name, param)

    for name, buffer in module.named_buffers():
        if verbose:
            buf_cnt += 1
            total_numel += buffer.numel()
            if getattr(buffer, "_materialized_data", False) is None:
                buf_lazy_cnt += 1
            else:
                non_lazy_numel += buffer.numel()
        if isinstance(buffer, LazyTensor):
            apply_fn(name, buffer)

    if verbose:
        non_lazy_numel_ratio = non_lazy_numel / total_numel * 100 if non_lazy_numel != 0 else 0
        _print_rank_0(f'Param lazy rate: {param_lazy_cnt}/{param_cnt}')
        _print_rank_0(f'Buffer lazy rate: {buf_lazy_cnt}/{buf_cnt}')
        _print_rank_0(
            f'Non lazy numel: {non_lazy_numel} ({non_lazy_numel/1024**2:.3f} M), ratio: {non_lazy_numel_ratio}%')

    return module
||
|
def _print_rank_0(*args, **kwargs):
|
||
|
if not dist.is_initialized() or dist.get_rank() == 0:
|
||
|
print(*args, **kwargs)
|
||
2 years ago
|
|
||
|
|
||
|
def _is_int_tuple(args) -> bool:
|
||
|
if not isinstance(args, tuple):
|
||
|
return False
|
||
|
for x in args:
|
||
|
if not isinstance(x, int):
|
||
|
return False
|
||
|
return True
|