ColossalAI/colossalai/auto_parallel/tensor_shard/utils/misc.py


import functools

import torch

from colossalai.logging import get_dist_logger
from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException

__all__ = ['ignore_sharding_exception']
def ignore_sharding_exception(func):
    """
    A function wrapper to handle the ShardingSpecException in the wrapped function.
    If a ShardingSpecException occurs, the wrapper returns None instead of raising.

    Usage:
        # mute the sharding spec exception in the function
        @ignore_sharding_exception
        def do_something():
            ...
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        logger = get_dist_logger()
        try:
            rst = func(*args, **kwargs)
            return rst
        except ShardingSpecException as e:
            logger.debug(e)
            return None

    return wrapper
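
# A minimal usage sketch (illustrative only, not part of the upstream module):
# the decorated function returns None instead of propagating the exception, so
# callers can simply filter out invalid sharding candidates. `_demo_illegal_shard`
# is a hypothetical helper introduced here purely for demonstration.
@ignore_sharding_exception
def _demo_illegal_shard():
    raise ShardingSpecException('this sharding plan is illegal')

# Calling _demo_illegal_shard() logs the exception at DEBUG level and returns None.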
def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tensor):
    """
    This function checks whether the ShardingSpec is valid for the physical tensor.
    This check includes 2 items:
    1. the sharding spec covers all dimensions of the physical tensor
    2. the size of each sharded dimension is divisible by the number of devices it is sharded over.
    """
    # make sure all dims are covered in the sharding spec
    sharding_len = len(sharding_spec.sharding_sequence)
    tensor_num_dim = tensor.dim()
    num_devices_in_col = sharding_spec.device_mesh.mesh_shape[0]
    num_devices_in_row = sharding_spec.device_mesh.mesh_shape[1]
    assert sharding_len == tensor_num_dim, \
        f'The ShardingSpec ({sharding_spec.sharding_sequence}) is created for a {sharding_len}-dimension tensor, but the given tensor is {tensor_num_dim}-dimension ({tensor.shape}).'

    # make sure the sharding is valid for each dim
    for i in range(tensor_num_dim):
        dim_size = tensor.shape[i]
        dim_spec = sharding_spec.sharding_sequence[i]

        if str(dim_spec).startswith('S'):
            # a sharded dim spec looks like S0, S1 or S01; the digits name the
            # device mesh axes this tensor dimension is sharded over
            devices_str = str(dim_spec).lstrip('S')
            num_devices = 1

            if '0' in devices_str:
                num_devices *= num_devices_in_col

            if '1' in devices_str:
                num_devices *= num_devices_in_row

            assert dim_size >= num_devices and dim_size % num_devices == 0, \
                f'The dimension at index {i} has size {dim_size}, which is not divisible by the {num_devices} devices it is sharded over.'
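

if __name__ == '__main__':
    # A hedged, illustrative self-check (not part of the upstream file). It
    # assumes the DeviceMesh and ShardingSpec constructor signatures of this
    # ColossalAI version: a (4, 8) tensor sharded as [S0, S1] on a 2 x 2 mesh
    # satisfies both checks, since 4 % 2 == 0 and 8 % 2 == 0.
    from colossalai.device.device_mesh import DeviceMesh

    physical_mesh_id = torch.arange(4)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape=(2, 2))
    tensor = torch.rand(4, 8)
    # dim_partition_dict maps a tensor dim to the device mesh axes it is sharded over
    sharding_spec = ShardingSpec(device_mesh, tensor.shape, dim_partition_dict={0: [0], 1: [1]})
    check_sharding_spec_validity(sharding_spec, tensor)    # passes silently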