@@ -36,8 +36,21 @@ class ChunkFullError(Exception):
    pass
def is_storage_empty(tensor: torch.Tensor) -> bool:
    return tensor.storage().size() == 0


def free_storage(tensor: torch.Tensor) -> None:
    if not is_storage_empty(tensor):
        tensor.storage().resize_(0)


def alloc_storage(tensor: torch.Tensor) -> None:
    if is_storage_empty(tensor):
        tensor.storage().resize_(tensor.numel())
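# These helpers rely on torch.Tensor.storage().resize_(): shrinking the storage to 0
# releases the underlying memory while the tensor object (shape, dtype, device) stays
# alive, and resizing back to numel() re-allocates an uninitialized buffer.
# A minimal illustrative sketch, assuming a CUDA device is available:
#
#   t = torch.empty(1024, dtype=torch.float16, device='cuda')
#   free_storage(t)     # memory released; t.storage().size() == 0
#   alloc_storage(t)    # buffer re-allocated; contents undefined until copied into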
class Chunk:
    """
    A chunk is a contiguous memory space which contains multiple tensors.
@@ -46,26 +59,37 @@ class Chunk:
        src_rank (int): the process which owns the chunk
        dtype (torch.dtype): the data type of the chunk
        init_device (torch.device): optional, the device where the tensor is initialized. The default value is None, which is the current GPU.
        force_data_on_cuda (bool): optional, if True, chunk.data is always on cuda. Defaults to False.
    """
    def __init__(self,
                 chunk_size: int,
                 src_rank: int,
                 dtype: torch.dtype,
                 init_device: Optional[torch.device] = None) -> None:
                 init_device: Optional[torch.device] = None,
                 force_data_on_cuda: bool = False) -> None:
        self.size = chunk_size
        self.utilized_size = 0
        self.src_rank = src_rank
        self.is_src_rank = gpc.get_local_rank(ParallelMode.DATA) == src_rank
        self.global_src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[src_rank]
        self.dtype = dtype
        self.device = init_device or get_current_device()
        self.data = torch.empty(chunk_size, dtype=dtype, device=self.device)
        device = init_device or get_current_device()
        if force_data_on_cuda:
            self.data = torch.empty(chunk_size, dtype=dtype, device=get_current_device())
            self._cpu_data = torch.empty(chunk_size, dtype=dtype)
            if device.type == 'cuda':
                free_storage(self._cpu_data)
            else:
                free_storage(self.data)
        else:
            self.data = torch.empty(chunk_size, dtype=dtype, device=device)
            self._cpu_data = None

        # we only keep the chunk in full in the process by which the tensor is owned
        if not self.is_src_rank:
            self.data.storage().resize_(0)
            free_storage(self._payload)

        # each tensor is associated with a TensorInfo to track meta info
        self.tensors_info: Dict[torch.Tensor, TensorInfo] = {}
        self.mem = self.size * self.data.element_size()
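        # With force_data_on_cuda, the chunk keeps two buffers: `data`, always allocated
        # on the CUDA device, and `_cpu_data` on the host. Only one of them holds
        # allocated storage at a time (the other is released via free_storage). The
        # `_payload` property defined further below selects whichever buffer currently
        # backs the chunk, so the rest of the class can stay device-agnostic.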
@@ -83,16 +107,16 @@ class Chunk:
        # raise exception when the chunk size is exceeded
        if new_utilized_size > self.size:
            raise ChunkFullError

        # set tensor state
        tensor_state = TensorState.FREE

        # if the process owns the rank, then copy the tensor to its chunk buffer
        # otherwise set its storage size to 0 to reduce memory consumption
        if self.is_src_rank:
            self.data[self.utilized_size:new_utilized_size].copy_(tensor.view(-1))
            self._payload[self.utilized_size:new_utilized_size].copy_(tensor.view(-1))
            tensor_state = TensorState.HOLD
            tensor.data = self.data[self.utilized_size:new_utilized_size].view(tensor.shape)
            tensor.data = self._payload[self.utilized_size:new_utilized_size].view(tensor.shape)
        else:
            tensor.storage().resize_(0)
        self.tensors_info[tensor] = TensorInfo(tensor_state, self.utilized_size, new_utilized_size)
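        # After append, the tensor no longer owns separate memory: on the owner rank its
        # .data is re-pointed at a view of the chunk payload, so parameter and chunk share
        # one buffer; on every other rank its storage is freed and only the TensorInfo
        # bookkeeping is kept.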
@@ -103,12 +127,12 @@ class Chunk:
        Release the memory space on processes which do not own the chunk.
        """
        if not self.is_src_rank:
            self.data.storage().resize_(0)
            free_storage(self._payload)
            self._update_tensors_state(TensorState.FREE)

    def _update_tensors_ptr(self) -> None:
        for tensor, tensor_info in self.tensors_info.items():
            tensor.data = self.data[tensor_info.offset:tensor_info.end].view(tensor.shape)
            tensor.data = self._payload[tensor_info.offset:tensor_info.end].view(tensor.shape)

    def _update_tensors_state(self, next_state: TensorState, prev_state: Optional[TensorState] = None):
        for tensor_info in self.tensors_info.values():
@@ -122,8 +146,8 @@ class Chunk:
        # recover the chunk on non-owner processes
        # and broadcast the chunk from the source to all processes
        if not self.is_src_rank:
            self.data.storage().resize_(self.size)
            self.data.data = self.data.to(get_current_device())
            alloc_storage(self._payload)
            self.move_device(get_current_device(), update_ptr=False)
        dist.broadcast(self.data, self.global_src_rank, group=gpc.get_group(ParallelMode.DATA))

        # update tensor meta info
@@ -131,15 +155,32 @@ class Chunk:
        if not self.is_src_rank:
            self._update_tensors_state(TensorState.HOLD, prev_state=TensorState.FREE)
    def move_device(self, device: torch.device) -> None:
    def move_device(self, device: torch.device, update_ptr: bool = True) -> None:
        """
        Move the chunk to a target device.

        Args:
            device (torch.device): the target device for data movement.
        """
        self.data.data = self.data.to(device)
        self._update_tensors_ptr()
        if self._payload.device == device:
            return
        if self._cpu_data is None:
            self.data.data = self.data.to(device)
        else:
            if device.type == 'cuda':
                # cpu -> cuda
                src = self._cpu_data
                dest = self.data
            else:
                # cuda -> cpu
                src = self.data
                dest = self._cpu_data
            alloc_storage(dest)
            dest.copy_(src)
            free_storage(src)
        if update_ptr:
            self._update_tensors_ptr()
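        # When _cpu_data exists (force_data_on_cuda), a device move never re-creates
        # self.data on the CPU: it copies between the two pre-built buffers and frees the
        # source's storage, so chunk.data itself always stays a CUDA tensor. The
        # update_ptr=False path lets callers such as access() and reduce() move the raw
        # buffer without immediately re-pointing the tensors' views at the new payload.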
    def reduce(self, is_all_reduce: bool = False) -> None:
        """
@@ -148,7 +189,7 @@ class Chunk:
        Args:
            is_all_reduce (bool): optional, whether to all-reduce the chunk. The default is false.
        """
        self.data.data = self.data.to(get_current_device())
        self.move_device(get_current_device(), update_ptr=False)
        if is_all_reduce:
            dist.all_reduce(self.data, group=gpc.get_group(ParallelMode.DATA))
        else:
@@ -187,8 +228,8 @@ class Chunk:
            data_slice (torch.Tensor): the tensor to be copied to the chunk
        """
        tensor_info = self.tensors_info[tensor]
        self.data[tensor_info.offset:tensor_info.end].copy_(data_slice.view(-1))
        tensor.data = self.data[tensor_info.offset:tensor_info.end].view(tensor.shape)
        self._payload[tensor_info.offset:tensor_info.end].copy_(data_slice.view(-1))
        tensor.data = self._payload[tensor_info.offset:tensor_info.end].view(tensor.shape)
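        # Writing through _payload rather than self.data makes the copy land in whichever
        # buffer currently backs the chunk, so this also works while the chunk is
        # offloaded to _cpu_data.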
    @property
    def can_release(self) -> bool:
@@ -225,7 +266,7 @@ class Chunk:
        """
        Check whether the chunk is empty.
        """
        return self.data.storage().size() == 0
        return is_storage_empty(self._payload)

    def __repr__(self) -> str:
        return f'Chunk: src rank={self.src_rank}, size={self.size}, utilization={self.utilized_size / self.size * 100:.2f}%, freed={self.is_empty}, tensor states={[info.state.name for info in self.tensors_info.values()]}'
@@ -235,8 +276,8 @@ class Chunk:
        """
        Check if the chunk has inf or nan values.
        """
        return torch.isinf(self.data[:self.utilized_size]).any().item() or \
               torch.isnan(self.data[:self.utilized_size]).any().item()
        return torch.isinf(self._payload[:self.utilized_size]).any().item() or \
               torch.isnan(self._payload[:self.utilized_size]).any().item()
    def copy_(self, dest_chunk: 'Chunk'):
        """
@@ -246,7 +287,7 @@ class Chunk:
        assert not dest_chunk.is_empty
        assert self.size == dest_chunk.size
        assert self.utilized_size == dest_chunk.utilized_size
        self.data.copy_(dest_chunk.data)
        self._payload.copy_(dest_chunk._payload)
        self._update_tensors_ptr()
    @property
@@ -254,7 +295,7 @@ class Chunk:
        """
        Get the device type of the chunk.
        """
        return self.data.device.type
        return self._payload.device.type
    def __hash__(self) -> int:
        return hash(id(self))
@@ -265,6 +306,12 @@ class Chunk:
    def get_tensors(self) -> List[torch.Tensor]:
        return list(self.tensors_info.keys())

    @property
    def _payload(self) -> torch.Tensor:
        if self._cpu_data is None or is_storage_empty(self._cpu_data):
            return self.data
        return self._cpu_data
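    # _payload is the single access point to the chunk's contents: it returns the CPU
    # buffer only when _cpu_data currently holds allocated storage, and the (possibly
    # CUDA-only) data tensor otherwise.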
class ChunkManager:
    """
@@ -285,6 +332,7 @@ class ChunkManager:
        self.enable_distributed_storage = enable_distributed_storage
        self.device = init_device or get_current_device()
        self.chunk_groups: Dict[str, Deque[Chunk]] = {}
        self.groups_force_data_on_cuda: Dict[str, bool] = {}
        self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = {}
        self.accessed_chunks: Set[Chunk] = set()
        self.lazy_release_tensors: List[torch.Tensor] = []
@@ -292,6 +340,17 @@ class ChunkManager:
        self.rank_load: Dict[str, torch.Tensor] = {}
        self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}

    def create_group(self, group_name: str, force_data_on_cuda: bool = False) -> None:
        """Create a chunk group.

        Args:
            group_name (str): group name
            force_data_on_cuda (bool, optional): If True, the data of chunks in this group is always on cuda. Defaults to False.
        """
        assert group_name not in self.chunk_groups
        self.chunk_groups[group_name] = deque()
        self.groups_force_data_on_cuda[group_name] = force_data_on_cuda
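    # With this change, a group must be created explicitly before tensors are appended to
    # it. Illustrative usage sketch (the group name and model below are hypothetical):
    #
    #   chunk_manager.create_group('fp16_param', force_data_on_cuda=True)
    #   for p in model.parameters():
    #       chunk_manager.append_tensor(p, 'fp16_param')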
    def append_tensor(self, tensor: torch.Tensor, group_name: str) -> None:
        """
        Append a tensor to a chunk.
@@ -304,19 +363,20 @@ class ChunkManager:
        if self.chunk_size is not None and tensor.numel() > self.chunk_size:
            raise ValueError(
                f'Cannot create chunk, got tensor numel ({tensor.numel()}) > chunk size ({self.chunk_size})')
        if group_name not in self.chunk_groups:
            self.chunk_groups[group_name] = deque()
        try:
            # append the tensor to the last chunk
            self.chunk_groups[group_name][-1].append(tensor)
        except (IndexError, ChunkFullError):
            # the except statement will be triggered when there is no chunk or
            # the last chunk in the chunk group is full
            # this will create a new chunk and allocate this chunk to its corresponding process
            chunk_size = self.chunk_size or tensor.numel()
            src_rank = self._get_next_src_rank(group_name)
            chunk = Chunk(chunk_size, src_rank, tensor.dtype, self.device)
            chunk = Chunk(chunk_size,
                          src_rank,
                          tensor.dtype,
                          self.device,
                          force_data_on_cuda=self.groups_force_data_on_cuda[group_name])
            if self.enable_distributed_storage and self.chunk_size is None:
                self.rank_load[group_name][src_rank] += chunk_size
@@ -387,7 +447,7 @@ class ChunkManager:
                # update the memory consumption after releasing
                self.total_mem[chunk.device_type] -= chunk.mem

    def move_chunk(self, chunk: Chunk, device: torch.device) -> None:
    def move_chunk(self, chunk: Chunk, device: torch.device, update_ptr: bool = True) -> None:
        """
        Move the chunk to the target device.
@@ -399,7 +459,7 @@ class ChunkManager:
            return
        if chunk.can_move_device and not chunk.is_empty:
            self.total_mem[chunk.device_type] -= chunk.mem
            chunk.move_device(device)
            chunk.move_device(device, update_ptr=update_ptr)
            self.total_mem[chunk.device_type] += chunk.mem

    def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None: