from types import MethodType
from typing import Callable, Optional, Union

import torch
import torch.nn as nn
from packaging import version
from torch import Tensor
from torch.nn import Parameter
from torch.utils._pytree import tree_map

from colossalai.logging import get_dist_logger

from .construction import ConstructorManager
from .pretrained import PretrainedManager

import colossalai._analyzer._subclasses._meta_registration  # noqa

# reference: https://pytorch.org/cppdocs/notes/tensor_creation.html
_NORMAL_FACTORY = [
    "arange",
    "full",
    "empty",
    "linspace",
    "logspace",
    "ones",
    "rand",
    "randn",
    "randint",
    "randperm",
    "zeros",
    "tensor",
]

# factory functions that do not support the meta tensor backend
_NO_META_FACTORY = [
    "eye",
]

_EARLY_MATERIALIZED_OPS = ["__getitem__", "split"]

# If your intent is to change the metadata of a Tensor (such as sizes / strides / storage / storage_offset)
# without autograd tracking the change, remove the .data / .detach() call and wrap the change in a
# `with torch.no_grad():` block.
# These ops cannot be unwrapped using .data
_CHANGE_META_OPS = ["_cudnn_rnn_flatten_weight", "requires_grad_", "__get__", "__set__", "numel", "size", "dim"]

# These ops are not related to the tensor's value and should not be rerun
_NO_RERUN_OPS = ["__get__", "numel", "size", "dim"]

_LEGACY_TENSOR_CONSTRUCTOR = {
    "FloatTensor": torch.float,
    "DoubleTensor": torch.double,
    "HalfTensor": torch.half,
    "BFloat16Tensor": torch.bfloat16,
    "ByteTensor": torch.uint8,
    "CharTensor": torch.int8,
    "ShortTensor": torch.short,
    "IntTensor": torch.int,
    "LongTensor": torch.long,
    "BoolTensor": torch.bool,
}

# These ops take at least one lazy tensor argument and may also take scalar arguments;
# scalar values should be converted to meta tensors. This is a hack for torch 2.0.
_EXPAND_SCALAR_OPS = [
    "where",
    "clamp",
    "clamp_min",
    "clamp_max",
    "clamp_",
    "clamp_min_",
    "clamp_max_",
]
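# For example (illustrative): for the ops listed above on torch >= 2.0, a Python scalar such as the
# ``0.0`` in ``torch.clamp(lazy_x, min=0.0)`` is wrapped as ``torch.tensor(0.0, device="meta")``
# inside ``__torch_function__`` before dispatch, so it can be traced with the meta tensor of ``lazy_x``.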

_old_tensor_factory = torch.tensor

_EMPTY_DATA = torch.empty(0)


class _MyTensor(Tensor):
    """This class is only for correctness verification."""

    _pre_op_fn: Callable[["LazyTensor"], None] = lambda *args: None

    default_device: Optional[torch.device] = None

    def __new__(cls, func, *args, concrete_data=None, **kwargs) -> "_MyTensor":
        cls._pre_op_fn()
        if concrete_data is not None:
            # keep the API consistent with LazyTensor
            data = concrete_data
        else:
            kwargs["device"] = cls.default_device
            data = func(*args, **kwargs)
        return Tensor._make_subclass(cls, data, require_grad=data.requires_grad)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        cls._pre_op_fn()
        return super().__torch_function__(func, types, args, kwargs)


def _data_tolist(tensor: torch.Tensor) -> list:
    """tolist() is not allowed on a subclass of Tensor; ``Tensor.data`` returns a plain Tensor."""
    return tensor.data.tolist()


def _convert_cls(tensor: "LazyTensor", target: torch.Tensor) -> torch.Tensor:
    """Convert a lazy tensor's class to the target's class, with the target's data.

    The reason why we change the class of a lazy tensor in place is that this easily handles shared
    modules/parameters, which are common in huggingface models. If we created a new tensor and updated
    the module with ``setattr(module, name, param)``, shared parameters would not be updated, and we
    would have to track all shared parameters and update them manually.

    Args:
        tensor (LazyTensor): the LazyTensor to be converted
        target (torch.Tensor): target tensor

    Returns:
        torch.Tensor: the converted tensor
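
    Example:
        An illustrative sketch, assuming ``linear_a`` and ``linear_b`` are modules sharing one lazy weight:

        >>> linear_b.weight = linear_a.weight  # tied parameters point to the same LazyTensor
        >>> linear_a.weight.materialize()  # calls _convert_cls() under the hood
        >>> assert linear_a.weight is linear_b.weight  # the tie survives materialization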
"""
cls_to_become = Parameter if isinstance ( tensor , Parameter ) else torch . Tensor
tensor . __class__ = cls_to_become
if cls_to_become is Parameter :
# to fit UninitializedParameter
delattr ( tensor , " _is_param " )
tensor . data = target
tensor . requires_grad = target . requires_grad
# subclass of torch.Tensor does not have tolist() method
# overwrite this method after materialization or distribution
tensor . tolist = MethodType ( _data_tolist , tensor )
return tensor


class LazyTensor(torch.Tensor):
    """A naive implementation of LazyTensor (https://arxiv.org/pdf/2102.13267.pdf).

    Usage:
        1. Use ``LazyTensor`` instead of ``torch.Tensor``.
        >>> x = LazyTensor(torch.zeros, 2, 3)
        >>> x += 1
        >>> y = x * x
        >>> y = y.cuda().half()
        >>> y[0, 0] = 0
        >>> y = y.materialize()  # materialize the tensor
        >>> print(y)
        tensor([[0., 1., 1.],
                [1., 1., 1.]], device='cuda:0', dtype=torch.float16)

    Warnings:
        1. Cases that ``LazyTensor`` can't deal with.
        >>> x = LazyTensor(torch.ones, 2, 3)
        >>> x[0, 0] = -x[0, 0]  # this will cause infinite recursion
        >>> y = x.clone()
        >>> x.add_(1)  # modifying the origin tensor after cloning leads to wrong materialization
        >>> z = x.tolist()
        >>> x.zero_()  # modifying the origin tensor after tolist() is not allowed
        >>> nn.utils.weight_norm(self.conv, name="weight", dim=2)  # applying weight norm on a lazy tensor is not allowed

        2. Cases that ``LazyTensor`` becomes eager (early materialization).
        >>> b = a[:, 2:]  # getting a slice of a lazy tensor triggers early materialization
        >>> chunks = a.split(3)  # this also triggers early materialization
        >>> x.data = torch.rand(2, 3)  # directly setting the data of a lazy tensor triggers early materialization
    """

    _repr = True
    _meta_data: Optional[torch.Tensor] = None  # shape, dtype, device
    _pre_op_fn: Callable[["LazyTensor"], None] = lambda *args: None
    default_device: Optional[torch.device] = None
    _device: torch.device  # fake device of the meta tensor

    @staticmethod
    def __new__(cls, func, *args, meta_data=None, concrete_data=None, **kwargs):
        # tips for torch 2.0:
        # torch 2.0 disables torch dispatch for subclasses of tensor, so MetaTensor cannot be used.
        # Now the lazy tensor handles both device injection and the meta tensor itself.
        if concrete_data is not None:
            # some ops don't support the meta backend and should have concrete data
            elem = concrete_data
        else:
            if meta_data is None:
                with ConstructorManager.disable():
                    # disable creating lazy tensors in inner ops; this is a hack for torch 2.0
                    meta_data = func(*args, **{**kwargs, "device": "meta"})
            elem = meta_data
        # As a meta tensor's __class__ cannot later be changed to torch.Tensor, we use an empty real tensor here
        r = torch.Tensor._make_subclass(cls, _EMPTY_DATA, require_grad=elem.requires_grad)
        r._meta_data = meta_data
        return r

    def __init__(self, func, *args, meta_data=None, concrete_data=None, **kwargs):
        self._device = torch.device(kwargs.get("device", None) or "cpu")
        if func.__name__ in _NORMAL_FACTORY:
            kwargs = {**kwargs, "device": LazyTensor.default_device}
        self._factory_method = (func, args, kwargs)  # (func, args, kwargs)
        self._op_buffer = []  # list of (func, args, kwargs) to replay
        self._materialized_data: Optional[torch.Tensor] = concrete_data  # materialized data

    @property
    def device(self) -> torch.device:
        return self._materialized_data.device if self._materialized_data is not None else self._device

    def __repr__(self):
        return f"LazyTensor(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})"

    def materialize(self) -> torch.Tensor:
        """Materialize the ``LazyTensor`` to ``torch.Tensor`` by modifying ``__class__`` (in place).

        Returns:
            torch.Tensor: The materialized tensor (self).
        """
        target = self._materialize_data()
        self.clean()
        return _convert_cls(self, target)

    def clean(self) -> None:
        """Clean all stored operations, meta data and materialized data, which prevents memory leaks.
        This should be called after all tensors are materialized.
        """
        delattr(self, "_factory_method")
        delattr(self, "_op_buffer")
        delattr(self, "_materialized_data")
        delattr(self, "_meta_data")

    @staticmethod
    def _replace_with_materialized(x):
        if isinstance(x, LazyTensor):
            return x._materialize_data()
        return x

    def _materialize_data(self) -> torch.Tensor:
        # self._materialized_data should be generated after the first call of this function
        if self._materialized_data is None:
            # apply the factory method
            func, args, kwargs = self._factory_method

            # apply the cached op sequence
            self._pre_op_fn()
            init_val = func(
                *tree_map(self._replace_with_materialized, args), **tree_map(self._replace_with_materialized, kwargs)
            )
            self._materialized_data = self._rerun_ops(init_val)
        return self._materialized_data

    def _rerun_ops(self, target=None) -> torch.Tensor:
        """Do lazy execution by rerunning all (stored) related operations.

        Args:
            target (torch.Tensor, optional): Initial value of the target tensor (self). Defaults to None.
        """

        def replace(x):
            if x is self:
                return target
            elif isinstance(x, LazyTensor):
                return x._materialize_data()
            return x

        packed = None

        for func, args, kwargs in self._op_buffer:
            if func == torch.Tensor.requires_grad_:
                packed = func, args, kwargs  # requires_grad should be set last
            else:
                self._pre_op_fn()
                o = func(*tree_map(replace, args), **tree_map(replace, kwargs))
                target = o if isinstance(o, torch.Tensor) else target  # if func returns a non-Tensor, discard the value

        # super-dainiu: set requires_grad after all in-place ops are done
        if packed is not None:
            func, args, kwargs = packed
            func(*tree_map(replace, args), **tree_map(replace, kwargs))

        return target

    # cache everything with __torch_function__
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func.__name__ in _EARLY_MATERIALIZED_OPS:
            # these ops cannot be lazy, and the related tensors should be early materialized
            tree_map(cls._replace_with_materialized, args)
            tree_map(cls._replace_with_materialized, kwargs)
        is_inplace: bool = (
            func.__name__.endswith("_")
            and not (func.__name__.endswith("__"))
            or func.__name__ in ("__setitem__", "__set__")
        )

        is_change_meta_op: bool = func.__name__ in _CHANGE_META_OPS

        if isinstance(func, torch._C.ScriptMethod):
            # FIXME(ver217): torch script functions are not verified

            target = None

            def unwrap(x):
                if isinstance(x, LazyTensor):
                    return x._meta_data
                return x

            target: LazyTensor = args[0].clone()
            target._op_buffer.append((func, args, kwargs))
            target._meta_data = getattr(target._meta_data, func.name)(
                *tree_map(unwrap, args[1:]), **tree_map(unwrap, kwargs)
            )
            return target
        else:
            meta_to_lazy = {}

            def unwrap(x):
                if isinstance(x, LazyTensor):
                    if x._materialized_data is not None:
                        # for an early materialized tensor, use its materialized data directly
                        return x._materialized_data if is_change_meta_op else x._materialized_data.data
                    t = x if is_inplace else x.clone()
                    if func.__name__ not in _NO_RERUN_OPS:
                        t._op_buffer.append((func, args, kwargs))
                    meta = x._meta_data if is_change_meta_op else x._meta_data.data
                    meta_to_lazy[meta] = t
                    return meta
                elif (
                    version.parse(torch.__version__) >= version.parse("2.0.0")
                    and func.__name__ in _EXPAND_SCALAR_OPS
                    and not isinstance(x, torch.Tensor)
                ):
                    return _old_tensor_factory(x, device="meta")
                return x

            def wrap(y, i=None):
                if isinstance(y, torch.Tensor):
                    if y.is_meta:
                        if y in meta_to_lazy:
                            # in-place op, just return the origin lazy tensor
                            return meta_to_lazy[y]
                        else:
                            # out-of-place op, create a new lazy tensor
                            fn = lambda *a, **kw: func(*a, **kw) if i is None else func(*a, **kw)[i]
                            fn.__name__ = func.__name__
                            lazy_y = LazyTensor(fn, *args, meta_data=y, **kwargs)
                            return lazy_y
                    else:
                        # for early materialized tensor
                        return LazyTensor(lambda: None, concrete_data=y)
                return y

            cls._pre_op_fn()
            with ConstructorManager.disable():
                # disable creating lazy tensors in inner ops; this is a hack for torch 2.0
                o = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
            if isinstance(o, (tuple, list)):
                return type(o)(wrap(y, i=i) for i, y in enumerate(o))
            return wrap(o)

    def to(self, *args, **kwargs) -> torch.Tensor:
        if self._materialized_data is not None:
            return LazyTensor(lambda: None, concrete_data=self._materialized_data.to(*args, **kwargs))
        device = None

        def replace(x):
            # pick out the device argument (str / int / torch.device) and swap it for the meta device
            nonlocal device
            if isinstance(x, (str, int, torch.device)) and not isinstance(x, bool):
                device = x
                return torch.device("meta")
            return x

        meta_data = self._meta_data.to(*tree_map(replace, args), **tree_map(replace, kwargs))

        if meta_data is self._meta_data and device == self.device:
            return self

        def factory_fn(t: torch.Tensor, **kw):
            return t.to(*args, **kwargs)

        return LazyTensor(factory_fn, self, meta_data=meta_data, device=device)

    def cpu(self, memory_format: torch.memory_format = torch.preserve_format):
        return self.to(device=torch.device("cpu"), memory_format=memory_format)

    def cuda(self, device=None, non_blocking=False, memory_format: torch.memory_format = torch.preserve_format):
        device = torch.device(device or "cuda")
        return self.to(device=device, non_blocking=non_blocking, memory_format=memory_format)

    def clone(self) -> "LazyTensor":
        def factory_fn(t: torch.Tensor, **kw):
            # if self is materialized, return self
            return t.clone()

        target = LazyTensor(factory_fn, self, meta_data=self._meta_data)

        return target

    def detach(self) -> Tensor:
        return self

    def __deepcopy__(self, memo):
        if not self.is_leaf:
            raise RuntimeError(
                "Only Tensors created explicitly by the user "
                "(graph leaves) support the deepcopy protocol at the moment"
            )
        if id(self) in memo:
            return memo[id(self)]

        def factory_fn(t: torch.Tensor, **kw):
            # if self is materialized, return self
            return _copy_tensor(t, t.requires_grad)

        if self._materialized_data is not None:
            # self is early materialized
            copied = _copy_tensor(self._materialized_data, self.requires_grad)
            target = LazyTensor(lambda: None, concrete_data=copied)
        else:
            target = LazyTensor(factory_fn, self, meta_data=self._meta_data)

        if isinstance(self, Parameter):
            # hack isinstance check of Parameter
            target._is_param = True

        memo[id(self)] = target
        return target

    @property
    def data(self):
        return self

    @data.setter
    def data(self, other: "LazyTensor"):
        """This is slightly different from the original ``data`` setter.

        E.g.:
            >>> a = torch.randn(3, 3)  # a is a Tensor
            >>> b = torch.rand(2, 2)
            >>> a.data = b
            >>> b.add_(1)  # this will affect a

            >>> x = torch.randn(3, 3)  # x is a LazyTensor
            >>> y = torch.rand(2, 2)  # y is a LazyTensor
            >>> x.data = y
            >>> y.add_(1)  # this will not affect x
        """
        if other is self:
            return

        def replace(x):
            if x is other:
                return self
            return x

        for func, args, kwargs in [other._factory_method, *other._op_buffer]:
            self._op_buffer.append((func, tree_map(replace, args), tree_map(replace, kwargs)))

    def tolist(self) -> list:
        # Though self.__class__ is modified to torch.Tensor, on the C++ side it is still a subclass of torch.Tensor,
        # and a subclass of torch.Tensor does not have the tolist() method.
        t = self._materialize_data()
        return t.tolist()

    def __hash__(self):
        return id(self)

    def __rpow__(self, other):
        dtype = torch.result_type(self, other)
        return torch.tensor(other, dtype=dtype, device=self.device) ** self


class LazyInitContext:
    """Context manager for lazy initialization. Enables initializing the model without allocating real memory.

    Args:
        tensor_cls (Union[_MyTensor, LazyTensor], optional): This is only for testing. Defaults to LazyTensor.
        default_device (Optional[Union[torch.device, str, int]], optional): Default device for initialization.
            If it's cuda, initialization will be accelerated, but cuda memory will be allocated. By default it's cpu.
            Defaults to None.
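
    Example:
        An illustrative sketch; ``MyModel`` stands for any ``nn.Module``:

        >>> with LazyInitContext():
        ...     model = MyModel()  # parameters and buffers are created as LazyTensors; no real memory is allocated
        >>> model = LazyInitContext.materialize(model)  # materialize all parameters and buffers in place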
"""
_replaced : bool = False
def __init__ (
self ,
tensor_cls : Union [ _MyTensor , LazyTensor ] = LazyTensor ,
default_device : Optional [ Union [ torch . device , str , int ] ] = None ,
) :
assert tensor_cls is LazyTensor or tensor_cls is _MyTensor
self . tensor_cls = tensor_cls
self . old_default_device = LazyTensor . default_device
self . default_device = default_device
def __enter__ ( self ) :
if LazyInitContext . _replaced :
raise RuntimeError ( f " LazyInitContext is not reentrant " )
LazyInitContext . _replaced = True
self . old_default_device = self . tensor_cls . default_device
self . tensor_cls . default_device = self . default_device
def wrap_factory_method ( target ) :
# factory functions (eg. torch.empty())
def wrapper ( * args , * * kwargs ) :
return self . tensor_cls ( target , * args , * * kwargs )
return wrapper , target
def wrap_factory_like_method ( orig_target , target ) :
# factory_like functions (eg. torch.empty_like())
def wrapper ( * args , * * kwargs ) :
orig_t = args [ 0 ]
return self . tensor_cls (
orig_target , * orig_t . shape , * args [ 1 : ] , device = orig_t . device , dtype = orig_t . dtype , * * kwargs
)
return wrapper , target
def wrap_legacy_constructor ( target , dtype ) :
# legacy constructor (e.g. torch.LongTensor())
def wrapper ( * args , * * kwargs ) :
if len ( args ) == 1 and isinstance ( args [ 0 ] , torch . Tensor ) :
# (Tensor other)
return args [ 0 ]
elif len ( args ) == 1 :
# (object data, *, torch.device device)
kwargs = { * * kwargs , " dtype " : dtype }
replaced , orig = self . overrides [ " tensor " ]
return replaced ( * args , * * kwargs )
elif _is_int_tuple ( args ) :
# (tuple of ints size, *, torch.device device)
kwargs = { * * kwargs , " dtype " : dtype }
replaced , orig = self . overrides [ " empty " ]
return replaced ( * args , * * kwargs )
else :
raise TypeError (
f " new() received an invalid combination of arguments - got { tuple ( type ( x ) for x in args ) } , but expected one of: \n * (Tensor other) \n * (tuple of ints size, *, torch.device device) \n * (object data, *, torch.device device) "
)
return wrapper , target
def wrap_no_meta_factory ( target ) :
# factory functions which don't support meta tensor backend
def wrapper ( * args , * * kwargs ) :
tensor = target ( * args , * * kwargs )
return self . tensor_cls ( lambda : None , concrete_data = tensor )
return wrapper , target

        overrides = {
            target: wrap_factory_method(getattr(torch, target))
            for target in _NORMAL_FACTORY
            if callable(getattr(torch, target, None))
        }

        overrides.update(
            {
                target + "_like": wrap_factory_like_method(getattr(torch, target), getattr(torch, target + "_like"))
                for target in _NORMAL_FACTORY
                if callable(getattr(torch, target + "_like", None))
            }
        )

        overrides.update(
            {
                target: wrap_legacy_constructor(getattr(torch, target), dtype)
                for target, dtype in _LEGACY_TENSOR_CONSTRUCTOR.items()
                if callable(getattr(torch, target, None))
            }
        )

        overrides.update(
            {
                target: wrap_no_meta_factory(getattr(torch, target))
                for target in _NO_META_FACTORY
                if callable(getattr(torch, target, None))
            }
        )

        # keep a reference on the instance; wrap_legacy_constructor looks overrides up via ``self.overrides``
        self.overrides = overrides

        ConstructorManager.apply(overrides)
        PretrainedManager.inject()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.tensor_cls.default_device = self.old_default_device
        LazyInitContext._replaced = False
        ConstructorManager.clear()
        PretrainedManager.recover()

    @staticmethod
    def materialize(module: nn.Module, verbose: bool = False) -> nn.Module:
        """Initialize all ``Parameter`` objects from ``LazyTensor``. This function modifies the module in place.

        Args:
            module (nn.Module): Target ``nn.Module``
            verbose (bool): Whether to print the lazy initialization rate. Defaults to False.
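
        Example:
            An illustrative sketch; assumes ``model`` was built inside a ``LazyInitContext``:

            >>> model = LazyInitContext.materialize(model, verbose=True)  # logs lazy-init rates on rank 0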
"""
def apply_fn ( name : str , p : LazyTensor ) :
p . materialize ( )
return _apply_to_lazy_module ( module , apply_fn , verbose )
def _apply_to_lazy_module (
module : nn . Module , apply_fn : Callable [ [ str , torch . Tensor ] , None ] , verbose : bool = False
) - > nn . Module :
if verbose :
# verbose info
param_cnt = 0
param_lazy_cnt = 0
buf_cnt = 0
buf_lazy_cnt = 0
total_numel = 0
non_lazy_numel = 0
for name , p in module . named_parameters ( ) :
if verbose :
param_cnt + = 1
total_numel + = p . numel ( )
if getattr ( p , " _materialized_data " , False ) is None :
# if no _materialized_data attr, the tensor is not lazy
param_lazy_cnt + = 1
else :
non_lazy_numel + = p . numel ( )
if isinstance ( p , LazyTensor ) :
apply_fn ( name , p )
for name , buf in module . named_buffers ( ) :
if verbose :
buf_cnt + = 1
total_numel + = buf . numel ( )
if getattr ( buf , " _materialized_data " , False ) is None :
# if no _materialized_data attr, the tensor is not lazy
buf_lazy_cnt + = 1
else :
non_lazy_numel + = buf . numel ( )
if isinstance ( buf , LazyTensor ) :
apply_fn ( name , buf )
if verbose :
non_lazy_numel_ratio = non_lazy_numel / total_numel * 100 if non_lazy_numel != 0 else 0
logger = get_dist_logger ( )
logger . info ( f " Param lazy rate: { param_lazy_cnt } / { param_cnt } " , ranks = [ 0 ] )
logger . info ( f " Buffer lazy rate: { buf_lazy_cnt } / { buf_cnt } " , ranks = [ 0 ] )
logger . info (
f " Non lazy numel: { non_lazy_numel } ( { non_lazy_numel / 1024 * * 2 : .3f } M), ratio: { non_lazy_numel_ratio } % " ,
ranks = [ 0 ] ,
)
return module
def _is_int_tuple ( args ) - > bool :
if not isinstance ( args , tuple ) :
return False
for x in args :
if not isinstance ( x , int ) :
return False
return True
def _copy_tensor ( tensor : Tensor , requires_grad : bool ) - > Tensor :
copied = tensor . data . clone ( )
copied . requires_grad = requires_grad
return copied