@ -1,17 +1,18 @@
from types import MethodType
from types import MethodType
from typing import Callable , Dict, Optional, Union
from typing import Callable , Optional, Union
import torch
import torch
import torch . distributed as dist
import torch . nn as nn
import torch . nn as nn
from packaging import version
from torch import Tensor
from torch import Tensor
from torch . nn import Parameter
from torch . nn import Parameter
from torch . utils . _pytree import tree_map
from torch . utils . _pytree import tree_map
from colossalai . _analyzer . _subclasses import MetaTensor
from colossalai . logging import get_dist_logger
from colossalai . device . device_mesh import DeviceMesh
from colossalai . tensor . d_tensor import distribute_tensor
from . construction import ConstructorManager
from colossalai . tensor . d_tensor . sharding_spec import ShardingSpec
import colossalai . _analyzer . _subclasses . _meta_registration # noqa
# reference: https://pytorch.org/cppdocs/notes/tensor_creation.html
# reference: https://pytorch.org/cppdocs/notes/tensor_creation.html
_NORMAL_FACTORY = [
_NORMAL_FACTORY = [
@ -41,6 +42,9 @@ _EARLY_MATERIALIZED_OPS = ["__getitem__", "split"]
# These ops cannot be unwrapped using .data
# These ops cannot be unwrapped using .data
_CHANGE_META_OPS = [ " _cudnn_rnn_flatten_weight " , " requires_grad_ " , " __get__ " , " __set__ " , " numel " , " size " , " dim " ]
_CHANGE_META_OPS = [ " _cudnn_rnn_flatten_weight " , " requires_grad_ " , " __get__ " , " __set__ " , " numel " , " size " , " dim " ]
# These ops are not related to the tensor value and should not be rerun.
# Entries are compared against ``func.__name__`` before an op is appended to a
# tensor's ``_op_buffer``, so they must match the op names exactly (no padding).
_NO_RERUN_OPS = ["__get__", "numel", "size", "dim"]
_LEGACY_TENSOR_CONSTRUCTOR = {
_LEGACY_TENSOR_CONSTRUCTOR = {
" FloatTensor " : torch . float ,
" FloatTensor " : torch . float ,
" DoubleTensor " : torch . double ,
" DoubleTensor " : torch . double ,
@ -54,6 +58,20 @@ _LEGACY_TENSOR_CONSTRUCTOR = {
" BoolTensor " : torch . bool ,
" BoolTensor " : torch . bool ,
}
}
# These ops take at least one lazy tensor argument and may also take a scalar argument.
# On torch >= 2.0.0 a non-tensor argument of these ops is converted to a meta tensor
# before the op is run on meta data (a workaround for torch 2.0).
# Entries are compared against ``func.__name__``, so they must match exactly (no padding).
_EXPAND_SCALAR_OPS = [
    "where",
    "clamp",
    "clamp_min",
    "clamp_max",
    "clamp_",
    "clamp_min_",
    "clamp_max_",
]
_old_tensor_factory = torch . tensor
_EMPTY_DATA = torch . empty ( 0 )
_EMPTY_DATA = torch . empty ( 0 )
@ -145,34 +163,48 @@ class LazyTensor(torch.Tensor):
"""
"""
_repr = True
_repr = True
_meta_data : Optional [ Meta Tensor] = None # shape, dtype, device
_meta_data : Optional [ torch. Tensor] = None # shape, dtype, device
_pre_op_fn : Callable [ [ " LazyTensor " ] , None ] = lambda * args : None
_pre_op_fn : Callable [ [ " LazyTensor " ] , None ] = lambda * args : None
default_device : Optional [ torch . device ] = None
default_device : Optional [ torch . device ] = None
# Assigned in ``__init__`` from the ``device`` keyword (falls back to cpu) and returned by
# the ``device`` property while ``_materialized_data`` is None.
_device : torch . device # fake device of the meta tensor
# NOTE(review): this span reads like two interleaved revisions of ``LazyTensor.__new__`` -
# unchanged lines appear twice, changed lines once per revision. Confirm against the real file.
# One revision calls ``func`` on the meta device and wraps the result in ``MetaTensor``
# together with the requested device; the other keeps the plain tensor returned by
# ``func(..., device="meta")`` and makes that call under ``ConstructorManager.disable()``.
# Both finish by creating the subclass instance from the shared ``_EMPTY_DATA`` placeholder
# and attaching ``meta_data`` to it as ``_meta_data``.
@staticmethod
@staticmethod
def __new__ ( cls , func , * args , meta_data = None , concrete_data = None , * * kwargs ) :
def __new__ ( cls , func , * args , meta_data = None , concrete_data = None , * * kwargs ) :
# tips for torch 2.0:
# torch 2.0 disables torch dispatch for subclass of tensor
# MetaTensor cannot be used
# Now lazy tensor contains device injection and meta tensor
if concrete_data is not None :
if concrete_data is not None :
# some ops don't support meta backend and should have concrete data
# some ops don't support meta backend and should have concrete data
elem = concrete_data
elem = concrete_data
else :
else :
# meta data is only computed here when the caller did not supply it
if meta_data is None :
if meta_data is None :
device = kwargs . get ( " device " , " cpu " )
with ConstructorManager . disable ( ) :
elem = func ( * args , * * { * * kwargs , " device " : " meta " } )
# disable creating lazy tensors in inner ops; this is a hack for torch 2.0
meta_data = MetaTensor ( elem , device = device )
meta_data = func ( * args , * * { * * kwargs , " device " : " meta " } )
elem = meta_data . _tensor
elem = meta_data
# As a meta tensor cannot be modified __class__ to torch.Tensor, we should use an empty real tensor here
# As a meta tensor cannot be modified __class__ to torch.Tensor, we should use an empty real tensor here
# only ``elem.requires_grad`` is carried over onto the placeholder instance
r = torch . Tensor . _make_subclass ( cls , _EMPTY_DATA , require_grad = elem . requires_grad )
r = torch . Tensor . _make_subclass ( cls , _EMPTY_DATA , require_grad = elem . requires_grad )
r . _meta_data = meta_data
r . _meta_data = meta_data
return r
return r
# NOTE(review): this span reads like two interleaved revisions of ``LazyTensor.__init__``
# (unchanged lines appear twice); the ``_device`` assignment appears only once.
# Stores the factory call ``(func, args, kwargs)``, an empty op buffer and any concrete data
# supplied by the caller.
def __init__ ( self , func , * args , meta_data = None , concrete_data = None , * * kwargs ) :
def __init__ ( self , func , * args , meta_data = None , concrete_data = None , * * kwargs ) :
# record the requested device, falling back to cpu when ``device`` is missing or falsy
self . _device = torch . device ( kwargs . get ( " device " , None ) or " cpu " )
# for factory functions listed in ``_NORMAL_FACTORY`` the stored kwargs override ``device``
# with ``LazyTensor.default_device``
if func . __name__ in _NORMAL_FACTORY :
if func . __name__ in _NORMAL_FACTORY :
kwargs = { * * kwargs , " device " : LazyTensor . default_device }
kwargs = { * * kwargs , " device " : LazyTensor . default_device }
self . _factory_method = ( func , args , kwargs ) # (func, args, kwargs)
self . _factory_method = ( func , args , kwargs ) # (func, args, kwargs)
# NOTE(review): the trailing comment lists four fields, but the visible appends to
# ``_op_buffer`` push ``(func, args, kwargs)`` triples - confirm which is intended.
self . _op_buffer = [ ] # (func, args, kwargs, replace)
self . _op_buffer = [ ] # (func, args, kwargs, replace)
self . _materialized_data : Optional [ torch . Tensor ] = concrete_data # materialized data
self . _materialized_data : Optional [ torch . Tensor ] = concrete_data # materialized data
@property
def device(self) -> torch.device:
    """Return the device this tensor reports to callers.

    Returns:
        torch.device: The device of ``_materialized_data`` when that is set, otherwise the
            ``_device`` value stored on the instance.
    """
    if self._materialized_data is not None:
        return self._materialized_data.device
    return self._device
def __repr__(self):
    """Return a short description built from ``shape``, ``device`` and ``dtype`` only."""
    return f"LazyTensor(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})"
def materialize ( self ) - > torch . Tensor :
def materialize ( self ) - > torch . Tensor :
""" Materialize the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace).
""" Materialize the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace).
@ -183,20 +215,6 @@ class LazyTensor(torch.Tensor):
self . clean ( )
self . clean ( )
return _convert_cls ( self , target )
return _convert_cls ( self , target )
def distribute(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> torch.Tensor:
    """Distribute the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace).

    Args:
        device_mesh (DeviceMesh): Forwarded unchanged to ``distribute_tensor``.
        sharding_spec (ShardingSpec): Forwarded unchanged to ``distribute_tensor``.

    Returns:
        torch.Tensor: The result of ``_convert_cls(self, local_tensor)``; the previous
            docstring describes it as the distributed tensor (self).
    """
    target = self._materialize_data()
    # stored ops and meta data are dropped before the class is converted
    self.clean()
    local_tensor = distribute_tensor(target, device_mesh, sharding_spec)
    return _convert_cls(self, local_tensor)
def clean ( self ) - > None :
def clean ( self ) - > None :
""" Clean all stored operations, meta data and materialized data, which prevents memory leaking. This should be called after all tensors are materialized. """
""" Clean all stored operations, meta data and materialized data, which prevents memory leaking. This should be called after all tensors are materialized. """
delattr ( self , " _factory_method " )
delattr ( self , " _factory_method " )
@ -299,45 +317,80 @@ class LazyTensor(torch.Tensor):
# for early materialized tensor, use its materialized data directly
# for early materialized tensor, use its materialized data directly
return x . _materialized_data if is_change_meta_op else x . _materialized_data . data
return x . _materialized_data if is_change_meta_op else x . _materialized_data . data
t = x if is_inplace else x . clone ( )
t = x if is_inplace else x . clone ( )
t . _op_buffer . append ( ( func , args , kwargs ) )
if func . __name__ not in _NO_RERUN_OPS :
t . _op_buffer . append ( ( func , args , kwargs ) )
meta = x . _meta_data if is_change_meta_op else x . _meta_data . data
meta = x . _meta_data if is_change_meta_op else x . _meta_data . data
meta_to_lazy [ meta ] = t
meta_to_lazy [ meta ] = t
return meta
return meta
elif (
version . parse ( torch . __version__ ) > = version . parse ( " 2.0.0 " )
and func . __name__ in _EXPAND_SCALAR_OPS
and not isinstance ( x , torch . Tensor )
) :
return _old_tensor_factory ( x , device = " meta " )
return x
return x
def wrap ( y , i = None ) :
def wrap ( y , i = None ) :
if isinstance ( y , MetaTensor ) :
if isinstance ( y , torch . Tensor ) :
if y in meta_to_lazy :
if y . is_meta :
# inplace op, just return origin lazy tensor
if y in meta_to_lazy :
return meta_to_lazy [ y ]
# inplace op, just return origin lazy tensor
return meta_to_lazy [ y ]
else :
# out of place op, create new lazy tensor
fn = lambda * a , * * kw : func ( * a , * * kw ) if i is None else func ( * a , * * kw ) [ i ]
fn . __name__ = func . __name__
lazy_y = LazyTensor ( fn , * args , meta_data = y , * * kwargs )
return lazy_y
else :
else :
# out of place op, create new lazy tensor
# for early materialized tensor
fn = lambda * a , * * kw : func ( * a , * * kw ) if i is None else func ( * a , * * kw ) [ i ]
return LazyTensor ( lambda : None , concrete_data = y )
fn . __name__ = func . __name__
lazy_y = LazyTensor ( fn , * args , meta_data = y , * * kwargs )
return lazy_y
elif type ( y ) is Tensor :
# for early materialized tensor
return LazyTensor ( lambda : None , concrete_data = y )
return y
return y
cls . _pre_op_fn ( )
cls . _pre_op_fn ( )
o = func ( * tree_map ( unwrap , args ) , * * tree_map ( unwrap , kwargs ) )
with ConstructorManager . disable ( ) :
# to disable create lazy tensor in inner ops, this is a hack for torch 2.0
o = func ( * tree_map ( unwrap , args ) , * * tree_map ( unwrap , kwargs ) )
if isinstance ( o , ( tuple , list ) ) :
if isinstance ( o , ( tuple , list ) ) :
return type ( o ) ( wrap ( y , i = i ) for i , y in enumerate ( o ) )
return type ( o ) ( wrap ( y , i = i ) for i , y in enumerate ( o ) )
return wrap ( o )
return wrap ( o )
# NOTE(review): this span reads like two interleaved revisions. ``@classmethod``,
# ``def __torch_dispatch__`` and ``pass # skip`` appear to belong to a no-op
# ``__torch_dispatch__``; the remaining lines form ``to``. Confirm against the real file.
#
# ``to`` mirrors ``Tensor.to`` lazily:
#   * with materialized data, the conversion runs immediately and the result is wrapped
#     in a new ``LazyTensor`` as concrete data;
#   * otherwise the conversion is applied to ``_meta_data`` with every str / int /
#     ``torch.device`` argument swapped for the meta device, and a new ``LazyTensor`` is
#     returned whose factory replays ``.to(*args, **kwargs)`` on the tensor it is given.
@classmethod
def to ( self , * args , * * kwargs ) - > torch . Tensor :
def __torch_dispatch__ ( cls , func , types , args = ( ) , kwargs = None ) :
if self . _materialized_data is not None :
pass # skip
return LazyTensor ( lambda : None , concrete_data = self . _materialized_data . to ( * args , * * kwargs ) )
# last device-like argument seen by ``replace``; stays None when none was passed
device = None
def replace ( x ) :
nonlocal device
# bool is excluded explicitly because it is a subclass of int
if isinstance ( x , ( str , int , torch . device ) ) and not isinstance ( x , bool ) :
device = x
return torch . device ( " meta " )
return x
meta_data = self . _meta_data . to ( * tree_map ( replace , args ) , * * tree_map ( replace , kwargs ) )
# NOTE(review): when no device argument is passed ``device`` is None, so this comparison
# is False and a new LazyTensor is built even if the meta data is unchanged - confirm intended.
if meta_data is self . _meta_data and device == self . device :
return self
def factory_fn ( t : torch . Tensor , * * kw ) :
return t . to ( * args , * * kwargs )
# NOTE(review): ``device=None`` makes ``__init__`` fall back to cpu for ``_device``;
# verify this is right for dtype-only conversions of tensors created on another device.
return LazyTensor ( factory_fn , self , meta_data = meta_data , device = device )
def cpu(self, memory_format: torch.memory_format = torch.preserve_format):
    """Return the result of ``self.to`` targeting the cpu device.

    Args:
        memory_format (torch.memory_format, optional): Forwarded to ``self.to``.
            Defaults to ``torch.preserve_format``.

    Returns:
        Whatever ``self.to(device=torch.device("cpu"), memory_format=memory_format)`` returns.
    """
    return self.to(device=torch.device("cpu"), memory_format=memory_format)
def cuda(self, device=None, non_blocking=False, memory_format: torch.memory_format = torch.preserve_format):
    """Return the result of ``self.to`` targeting a CUDA device.

    Args:
        device (optional): Anything ``torch.device`` accepts (device object, string or
            index). ``None`` selects plain ``"cuda"``.
        non_blocking (bool, optional): Forwarded to ``self.to``. Defaults to False.
        memory_format (torch.memory_format, optional): Forwarded to ``self.to``.
            Defaults to ``torch.preserve_format``.

    Returns:
        Whatever ``self.to(device=..., non_blocking=..., memory_format=...)`` returns.
    """
    # Test against None rather than truthiness: an explicit index 0 is falsy and would
    # otherwise be replaced by the index-less "cuda" device.
    device = torch.device("cuda" if device is None else device)
    return self.to(device=device, non_blocking=non_blocking, memory_format=memory_format)
# NOTE(review): this span reads like two interleaved revisions of ``LazyTensor.clone``
# (unchanged lines appear twice). Confirm against the real file.
# Both build a new ``LazyTensor`` that reuses this tensor's ``_meta_data``. One revision's
# factory takes no arguments and materializes ``self`` inside the closure before cloning;
# the other passes ``self`` as the factory argument and clones the tensor it receives.
def clone ( self ) - > " LazyTensor " :
def clone ( self ) - > " LazyTensor " :
def factory_fn ( ) :
def factory_fn ( t : torch . Tensor , * * kw ) :
# if self is materialized, return self
# if self is materialized, return self
new_tensor = self . materialize ( ) if type ( self ) is LazyTensor else self
return t . clone ( )
return new_tensor . clone ( )
target = LazyTensor ( factory_fn , meta_data = self . _meta_data )
target = LazyTensor ( factory_fn , self , meta_data = self . _meta_data )
return target
return target
@ -353,17 +406,16 @@ class LazyTensor(torch.Tensor):
if id ( self ) in memo :
if id ( self ) in memo :
return memo [ id ( self ) ]
return memo [ id ( self ) ]
def factory_fn ( ) :
def factory_fn ( t : torch . Tensor , * * kw ) :
# if self is materialized, return self
# if self is materialized, return self
new_tensor = self . materialize ( ) if type ( self ) is LazyTensor else self
return _copy_tensor ( t , t . requires_grad )
return _copy_tensor ( new_tensor , new_tensor . requires_grad )
if self . _materialized_data is not None :
if self . _materialized_data is not None :
# self is early materialized
# self is early materialized
copied = _copy_tensor ( self . _materialized_data , self . requires_grad )
copied = _copy_tensor ( self . _materialized_data , self . requires_grad )
target = LazyTensor ( lambda : None , concrete_data = copied )
target = LazyTensor ( lambda : None , concrete_data = copied )
else :
else :
target = LazyTensor ( factory_fn , meta_data = self . _meta_data )
target = LazyTensor ( factory_fn , self , meta_data = self . _meta_data )
if isinstance ( self , Parameter ) :
if isinstance ( self , Parameter ) :
# hack isinstance check of parameter
# hack isinstance check of parameter
@ -394,14 +446,12 @@ class LazyTensor(torch.Tensor):
if other is self :
if other is self :
return
return
self . _op_buffer . append ( other . _factory_method )
def replace ( x ) :
def replace ( x ) :
if x is other :
if x is other :
return self
return self
return x
return x
for func , args , kwargs in other . _op_buffer :
for func , args , kwargs in [ other . _factory_method , * other . _op_buffer ] :
self . _op_buffer . append ( ( func , tree_map ( replace , args ) , tree_map ( replace , kwargs ) ) )
self . _op_buffer . append ( ( func , tree_map ( replace , args ) , tree_map ( replace , kwargs ) ) )
def tolist ( self ) - > list :
def tolist ( self ) - > list :
@ -455,7 +505,6 @@ class LazyInitContext:
default_device : Optional [ Union [ torch . device , str , int ] ] = None ,
default_device : Optional [ Union [ torch . device , str , int ] ] = None ,
) :
) :
assert tensor_cls is LazyTensor or tensor_cls is _MyTensor
assert tensor_cls is LazyTensor or tensor_cls is _MyTensor
self . overrides = { }
self . tensor_cls = tensor_cls
self . tensor_cls = tensor_cls
self . old_default_device = LazyTensor . default_device
self . old_default_device = LazyTensor . default_device
self . default_device = default_device
self . default_device = default_device
@ -478,7 +527,9 @@ class LazyInitContext:
# factory_like functions (eg. torch.empty_like())
# factory_like functions (eg. torch.empty_like())
def wrapper ( * args , * * kwargs ) :
def wrapper ( * args , * * kwargs ) :
orig_t = args [ 0 ]
orig_t = args [ 0 ]
return self . tensor_cls ( orig_target , * args [ 1 : ] , device = orig_t . device , dtype = orig_t . dtype , * * kwargs )
return self . tensor_cls (
orig_target , * orig_t . shape , * args [ 1 : ] , device = orig_t . device , dtype = orig_t . dtype , * * kwargs
)
return wrapper , target
return wrapper , target
@ -513,13 +564,13 @@ class LazyInitContext:
return wrapper , target
return wrapper , target
self . overrides = {
overrides = {
target : wrap_factory_method ( getattr ( torch , target ) )
target : wrap_factory_method ( getattr ( torch , target ) )
for target in _NORMAL_FACTORY
for target in _NORMAL_FACTORY
if callable ( getattr ( torch , target , None ) )
if callable ( getattr ( torch , target , None ) )
}
}
self . overrides . update (
overrides . update (
{
{
target + " _like " : wrap_factory_like_method ( getattr ( torch , target ) , getattr ( torch , target + " _like " ) )
target + " _like " : wrap_factory_like_method ( getattr ( torch , target ) , getattr ( torch , target + " _like " ) )
for target in _NORMAL_FACTORY
for target in _NORMAL_FACTORY
@ -527,7 +578,7 @@ class LazyInitContext:
}
}
)
)
self . overrides . update (
overrides . update (
{
{
target : wrap_legacy_constructor ( getattr ( torch , target ) , dtype )
target : wrap_legacy_constructor ( getattr ( torch , target ) , dtype )
for target , dtype in _LEGACY_TENSOR_CONSTRUCTOR . items ( )
for target , dtype in _LEGACY_TENSOR_CONSTRUCTOR . items ( )
@ -535,7 +586,7 @@ class LazyInitContext:
}
}
)
)
self . overrides . update (
overrides . update (
{
{
target : wrap_no_meta_factory ( getattr ( torch , target ) )
target : wrap_no_meta_factory ( getattr ( torch , target ) )
for target in _NO_META_FACTORY
for target in _NO_META_FACTORY
@ -543,14 +594,12 @@ class LazyInitContext:
}
}
)
)
for name , ( wrapper , orig ) in self . overrides . items ( ) :
ConstructorManager . apply ( overrides )
setattr ( torch , name , wrapper )
# NOTE(review): this span reads like two interleaved revisions of ``LazyInitContext.__exit__``
# (unchanged lines appear twice). Confirm against the real file.
# Restores ``default_device`` on ``tensor_cls`` from the value saved as ``old_default_device``
# and resets ``LazyInitContext._replaced``. One revision then puts every original callable
# from ``self.overrides`` back onto ``torch``; the other calls ``ConstructorManager.clear()``.
def __exit__ ( self , exc_type , exc_val , exc_tb ) :
def __exit__ ( self , exc_type , exc_val , exc_tb ) :
self . tensor_cls . default_device = self . old_default_device
self . tensor_cls . default_device = self . old_default_device
LazyInitContext . _replaced = False
LazyInitContext . _replaced = False
for name , ( wrapper , orig ) in self . overrides . items ( ) :
ConstructorManager . clear ( )
setattr ( torch , name , orig )
@staticmethod
@staticmethod
def materialize ( module : nn . Module , verbose : bool = False ) - > nn . Module :
def materialize ( module : nn . Module , verbose : bool = False ) - > nn . Module :
@ -566,23 +615,6 @@ class LazyInitContext:
return _apply_to_lazy_module ( module , apply_fn , verbose )
return _apply_to_lazy_module ( module , apply_fn , verbose )
@staticmethod
def distribute(
    module: nn.Module, device_mesh: DeviceMesh, sharding_spec_dict: Dict[str, ShardingSpec], verbose: bool = False
) -> nn.Module:
    """Distribute all ``Parameter`` from ``LazyTensor``. This function will modify the module in-place.

    Args:
        module (nn.Module): Target ``nn.Module``.
        device_mesh (DeviceMesh): Forwarded to ``distribute`` on every tensor handed to the callback.
        sharding_spec_dict (Dict[str, ShardingSpec]): Sharding spec for each parameter/buffer.
            The key is the parameter/buffer name, and the value is its sharding spec.
        verbose (bool, optional): Whether to print lazy initialization rate. Defaults to False.

    Returns:
        nn.Module: The value returned by ``_apply_to_lazy_module``.

    Raises:
        KeyError: If a name passed to the callback has no entry in ``sharding_spec_dict``.
    """

    def apply_fn(name: str, p: LazyTensor):
        p.distribute(device_mesh, sharding_spec_dict[name])

    return _apply_to_lazy_module(module, apply_fn, verbose)
def _apply_to_lazy_module (
def _apply_to_lazy_module (
module : nn . Module , apply_fn : Callable [ [ str , torch . Tensor ] , None ] , verbose : bool = False
module : nn . Module , apply_fn : Callable [ [ str , torch . Tensor ] , None ] , verbose : bool = False
@ -622,20 +654,17 @@ def _apply_to_lazy_module(
if verbose :
if verbose :
non_lazy_numel_ratio = non_lazy_numel / total_numel * 100 if non_lazy_numel != 0 else 0
non_lazy_numel_ratio = non_lazy_numel / total_numel * 100 if non_lazy_numel != 0 else 0
_print_rank_0 ( f " Param lazy rate: { param_lazy_cnt } / { param_cnt } " )
logger = get_dist_logger ( )
_print_rank_0 ( f " Buffer lazy rate: { buf_lazy_cnt } / { buf_cnt } " )
logger . info ( f " Param lazy rate: { param_lazy_cnt } / { param_cnt } " , ranks = [ 0 ] )
_print_rank_0 (
logger . info ( f " Buffer lazy rate: { buf_lazy_cnt } / { buf_cnt } " , ranks = [ 0 ] )
f " Non lazy numel: { non_lazy_numel } ( { non_lazy_numel / 1024 * * 2 : .3f } M), ratio: { non_lazy_numel_ratio } % "
logger . info (
f " Non lazy numel: { non_lazy_numel } ( { non_lazy_numel / 1024 * * 2 : .3f } M), ratio: { non_lazy_numel_ratio } % " ,
ranks = [ 0 ] ,
)
)
return module
return module
def _print_rank_0(*args, **kwargs):
    """Call ``print`` only when ``torch.distributed`` is uninitialized or this process is rank 0.

    Args:
        *args: Positional arguments forwarded to ``print``.
        **kwargs: Keyword arguments forwarded to ``print``.
    """
    if not dist.is_initialized() or dist.get_rank() == 0:
        print(*args, **kwargs)
def _is_int_tuple ( args ) - > bool :
def _is_int_tuple ( args ) - > bool :
if not isinstance ( args , tuple ) :
if not isinstance ( args , tuple ) :
return False
return False