import functools
from typing import Any , Callable , Dict , List , Tuple , Type , Union
import torch
from colossalai . logging import get_dist_logger
from colossalai . tensor . sharding_spec import ShardingSpec , ShardingSpecException
__all__ = [ ' ignore_sharding_exception ' , ' pytree_map ' ]
def ignore_sharding_exception ( func ) :
"""
A function wrapper to handle the ShardingSpecException in the function .
If ShardingSpecException occurs , this function will return None .
Usage :
# mute the assertion error in the function
@ignore_sharding_exception
def do_something ( ) :
. . .
"""
@functools.wraps ( func )
def wrapper ( * args , * * kwargs ) :
try :
logger = get_dist_logger ( )
rst = func ( * args , * * kwargs )
return rst
except ShardingSpecException as e :
logger . debug ( e )
return None
return wrapper
def check_sharding_spec_validity ( sharding_spec : ShardingSpec , tensor : torch . Tensor ) :
"""
This function checks whether the ShardingSpec is valid for the physical tensor .
This check includes 3 items :
1. the sharding spec covers all dimensions of the physical tensor
2. the sharding spec for each dimension is divisible by the number of devices .
3. the sharding spec ' s entire shape must match the tensor shape
#
"""
# make sure all dims are covered in sharding spec
sharding_len = len ( sharding_spec . sharding_sequence )
tensor_num_dim = tensor . dim ( )
num_devices_in_col = sharding_spec . device_mesh . mesh_shape [ 0 ]
num_devices_in_row = sharding_spec . device_mesh . mesh_shape [ 1 ]
assert sharding_len == tensor_num_dim , \
f ' The ShardingSpec ( { sharding_spec . sharding_sequence } ) is created for { sharding_len } -dimension tensor, but the given tensor is { tensor_num_dim } -dimension ( { tensor . shape } ). '
# make sure the sharding is valid for each dim
for i in range ( tensor_num_dim ) :
dim_size = tensor . shape [ i ]
dim_spec = sharding_spec . sharding_sequence [ i ]
if str ( dim_spec ) . startswith ( ' S ' ) :
devices_str = str ( dim_spec ) . lstrip ( ' S ' )
num_devices = 1
if ' 0 ' in devices_str :
num_devices * = num_devices_in_col
if ' 1 ' in devices_str :
num_devices * = num_devices_in_row
assert dim_size > = num_devices and dim_size % num_devices == 0 , \
f ' The dimension at index { i } has value { dim_size } , but it is sharded over { num_devices } devices. '
# make sure the entire shape matches the physical tensor shape
assert sharding_spec . entire_shape == tensor . shape , \
f ' The entire_shape of the sharding spec { sharding_spec . entire_shape } does not match the tensor shape { tensor . shape } '
def pytree_map ( obj : Any , fn : Callable , process_types : Union [ Type , Tuple [ Type ] ] = ( ) , map_all : bool = False ) - > Any :
""" process object recursively, like pytree
Args :
obj ( : class : ` Any ` ) : object to process
fn ( : class : ` Callable ` ) : a function to process subobject in obj
process_types ( : class : ` type | tuple [ type ] ` ) : types to determine the type to process
map_all ( : class : ` bool ` ) : if map_all is True , then any type of element will use fn
Returns :
: class : ` Any ` : returns have the same structure of ` obj ` and type in process_types after map of ` fn `
"""
if isinstance ( obj , dict ) :
return { k : pytree_map ( obj [ k ] , fn , process_types , map_all ) for k in obj }
elif isinstance ( obj , tuple ) :
return tuple ( pytree_map ( o , fn , process_types , map_all ) for o in obj )
elif isinstance ( obj , list ) :
return list ( pytree_map ( o , fn , process_types , map_all ) for o in obj )
elif isinstance ( obj , process_types ) :
return fn ( obj )
else :
return fn ( obj ) if map_all else obj