import functools

import torch

from colossalai.logging import get_dist_logger
from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException

__all__ = ['ignore_sharding_exception']

def ignore_sharding_exception(func):
    """
    A function wrapper to handle the ShardingSpecException raised inside the function.
    If a ShardingSpecException occurs, the wrapped function returns None instead of re-raising it.

    Usage:
        # mute the ShardingSpecException raised in the function
        @ignore_sharding_exception
        def do_something():
            ...
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            logger = get_dist_logger()
            rst = func(*args, **kwargs)
            return rst
        except ShardingSpecException as e:
            logger.debug(e)
            return None

    return wrapper
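
# A minimal illustrative sketch (kept as a comment, not executed by this module):
# a function decorated with ignore_sharding_exception returns None when it raises
# ShardingSpecException, while any other exception still propagates. The name
# build_spec below is a placeholder, not an API of this library.
#
#     @ignore_sharding_exception
#     def build_spec():
#         raise ShardingSpecException('the requested sharding is infeasible')
#
#     assert build_spec() is None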


def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tensor):
    """
    This function checks whether the ShardingSpec is valid for the physical tensor.

    The check includes 3 items:

    1. the sharding spec covers all dimensions of the physical tensor
    2. the size of each sharded dimension is divisible by the number of devices it is sharded over
    3. the sharding spec's entire shape must match the tensor shape
    """
    # make sure all dims are covered in sharding spec
    sharding_len = len(sharding_spec.sharding_sequence)
    tensor_num_dim = tensor.dim()
    num_devices_in_col = sharding_spec.device_mesh.mesh_shape[0]
    num_devices_in_row = sharding_spec.device_mesh.mesh_shape[1]
    assert sharding_len == tensor_num_dim, \
        f'The ShardingSpec ({sharding_spec.sharding_sequence}) is created for {sharding_len}-dimension tensor, but the given tensor is {tensor_num_dim}-dimension ({tensor.shape}).'

    # make sure the sharding is valid for each dim
    for i in range(tensor_num_dim):
        dim_size = tensor.shape[i]
        dim_spec = sharding_spec.sharding_sequence[i]

        if str(dim_spec).startswith('S'):
            devices_str = str(dim_spec).lstrip('S')
            num_devices = 1

            if '0' in devices_str:
                num_devices *= num_devices_in_col
            if '1' in devices_str:
                num_devices *= num_devices_in_row

            assert dim_size >= num_devices and dim_size % num_devices == 0, \
                f'The dimension at index {i} has value {dim_size}, but it is sharded over {num_devices} devices.'

    # make sure the entire shape matches the physical tensor shape
    assert sharding_spec.entire_shape == tensor.shape, \
        f'The entire_shape of the sharding spec {sharding_spec.entire_shape} does not match the tensor shape {tensor.shape}'
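

# A minimal usage sketch (kept as a comment so importing this module has no side
# effects). The DeviceMesh import path and the ShardingSpec constructor arguments
# below are assumptions about the surrounding library, not guaranteed by this file.
#
#     import torch
#     from colossalai.device.device_mesh import DeviceMesh
#
#     physical_mesh_id = torch.arange(4)
#     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape=(2, 2))
#     tensor = torch.rand(8, 16)
#     # shard dim 0 along mesh axis 0 and dim 1 along mesh axis 1 -> sequence [S0, S1]
#     sharding_spec = ShardingSpec(device_mesh, tensor.shape, dim_partition_dict={0: [0], 1: [1]})
#     check_sharding_spec_validity(sharding_spec, tensor)    # raises AssertionError if the spec is invalid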