import functools

import torch

from colossalai.logging import get_dist_logger
from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException

__all__ = ['ignore_sharding_exception']


def ignore_sharding_exception(func):
    """
    A function wrapper to handle the ShardingSpecException in the function.
    If a ShardingSpecException occurs, the wrapped function will return None instead.

    Usage:
        # mute the ShardingSpecException raised inside the function
        @ignore_sharding_exception
        def do_something():
            ...
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        logger = get_dist_logger()
        try:
            rst = func(*args, **kwargs)
            return rst
        except ShardingSpecException as e:
            logger.debug(e)
            return None

    return wrapper


def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tensor):
    """
    This function checks whether the ShardingSpec is valid for the physical tensor.
    The check includes 2 items:
    1. the sharding spec covers all dimensions of the physical tensor;
    2. the size of each sharded dimension is divisible by the number of devices it is sharded over.
    """
    # make sure all dims are covered in sharding spec
    sharding_len = len(sharding_spec.sharding_sequence)
    tensor_num_dim = tensor.dim()
    num_devices_in_col = sharding_spec.device_mesh.mesh_shape[0]
    num_devices_in_row = sharding_spec.device_mesh.mesh_shape[1]
    assert sharding_len == tensor_num_dim, \
        f'The ShardingSpec ({sharding_spec.sharding_sequence}) is created for a {sharding_len}-dimension tensor, ' \
        f'but the given tensor has {tensor_num_dim} dimensions ({tensor.shape}).'

    # make sure the sharding is valid for each dim
    for i in range(tensor_num_dim):
        dim_size = tensor.shape[i]
        dim_spec = sharding_spec.sharding_sequence[i]

        if str(dim_spec).startswith('S'):
            # a sharded dim spec looks like S0, S1 or S01;
            # the digits name the device-mesh axes this dimension is sharded over
            devices_str = str(dim_spec).lstrip('S')
            num_devices = 1

            if '0' in devices_str:
                num_devices *= num_devices_in_col
            if '1' in devices_str:
                num_devices *= num_devices_in_row

            assert dim_size >= num_devices and dim_size % num_devices == 0, \
                f'The dimension at index {i} has size {dim_size}, ' \
                f'which cannot be evenly sharded over {num_devices} devices.'
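

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): the DeviceMesh and ShardingSpec constructor
# calls below are assumptions based on typical ColossalAI usage and may differ
# across versions; they are not part of this module's API.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from colossalai.device.device_mesh import DeviceMesh

    # assumed: a 2x2 logical mesh over 4 physical devices; init_process_group=False
    # keeps the sketch runnable in a single, non-distributed process
    physical_mesh_id = torch.arange(4)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape=(2, 2), init_process_group=False)

    # shard tensor dim 0 over mesh axis 0 and dim 1 over mesh axis 1, i.e. [S0, S1]
    tensor = torch.rand(8, 6)
    sharding_spec = ShardingSpec(device_mesh, tensor.shape, dim_partition_dict={0: [0], 1: [1]})

    # passes: both 8 and 6 are divisible by the 2 devices along each mesh axis
    check_sharding_spec_validity(sharding_spec, tensor)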