# mirror of https://github.com/hpcaitech/ColossalAI
import functools

import torch

from colossalai.logging import get_dist_logger
from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException

__all__ = ['ignore_sharding_exception']


def ignore_sharding_exception(func):
    """
    A function wrapper which catches the ShardingSpecException raised inside the wrapped function.
    If a ShardingSpecException occurs, the wrapped function returns None instead of raising.

    Usage:
        # mute the ShardingSpecException raised in the function
        @ignore_sharding_exception
        def do_something():
            ...
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            logger = get_dist_logger()
            rst = func(*args, **kwargs)
            return rst
        except ShardingSpecException as e:
            logger.debug(e)
            return None

    return wrapper
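
# A minimal usage sketch (the `build_spec_for_dim` helper below is hypothetical, not part
# of this module): any ShardingSpecException raised inside the decorated function is
# logged at debug level and the call simply evaluates to None, so callers can skip an
# invalid sharding strategy instead of crashing.
#
#     @ignore_sharding_exception
#     def build_spec_for_dim(dim_partition_dict):
#         ...  # construct and return a ShardingSpec, may raise ShardingSpecException
#
#     spec = build_spec_for_dim({0: [0, 1]})  # None if the spec was invalid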


def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tensor):
    """
    This function checks whether the ShardingSpec is valid for the physical tensor.
    This check includes 3 items:
        1. the sharding spec covers all dimensions of the physical tensor
        2. the size of every sharded dimension is divisible by the number of devices it is sharded over
        3. the sharding spec's entire shape must match the tensor shape
    """
    # make sure all dims are covered in sharding spec
    sharding_len = len(sharding_spec.sharding_sequence)
    tensor_num_dim = tensor.dim()
    num_devices_in_col = sharding_spec.device_mesh.mesh_shape[0]
    num_devices_in_row = sharding_spec.device_mesh.mesh_shape[1]
    assert sharding_len == tensor_num_dim, \
        f'The ShardingSpec ({sharding_spec.sharding_sequence}) is created for a {sharding_len}-dimension tensor, but the given tensor is {tensor_num_dim}-dimension ({tensor.shape}).'

    # make sure the sharding is valid for each dim
    for i in range(tensor_num_dim):
        dim_size = tensor.shape[i]
        dim_spec = sharding_spec.sharding_sequence[i]

        if str(dim_spec).startswith('S'):
            # a sharded dim spec looks like S0, S1 or S01, where the digits name the
            # device mesh axes this tensor dimension is sharded over
            devices_str = str(dim_spec).lstrip('S')
            num_devices = 1

            if '0' in devices_str:
                num_devices *= num_devices_in_col
            if '1' in devices_str:
                num_devices *= num_devices_in_row

            assert dim_size >= num_devices and dim_size % num_devices == 0, \
                f'The dimension at index {i} has value {dim_size}, but it is sharded over {num_devices} devices.'

    # make sure the entire shape matches the physical tensor shape
    assert sharding_spec.entire_shape == tensor.shape, \
        f'The entire_shape of the sharding spec {sharding_spec.entire_shape} does not match the tensor shape {tensor.shape}'
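
# A minimal usage sketch (the DeviceMesh/ShardingSpec construction below is an assumption
# and is not verified against the current constructor signatures): for an (8, 8) tensor
# sharded as [S01, R] on a 2x2 mesh, dimension 0 must be divisible by 2 * 2 = 4 devices,
# which holds here, so the check passes silently.
#
#     physical_mesh_id = torch.arange(4)
#     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape=(2, 2))
#     tensor = torch.rand(8, 8)
#     sharding_spec = ShardingSpec(device_mesh, tensor.shape, dim_partition_dict={0: [0, 1]})
#     check_sharding_spec_validity(sharding_spec, tensor)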