mirror of https://github.com/hpcaitech/ColossalAI
101 lines
3.8 KiB
Python
101 lines
3.8 KiB
Python
import functools
|
|
from typing import Any, Callable, Tuple, Type, Union
|
|
|
|
import torch
|
|
|
|
from colossalai.logging import get_dist_logger
|
|
from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException
|
|
|
|
# Names exported by ``from <this module> import *``.
# NOTE(review): check_sharding_spec_validity is defined in this module but is not
# listed here; confirm whether that omission is intentional.
__all__ = ["ignore_sharding_exception", "pytree_map"]
|
|
|
|
|
|
def ignore_sharding_exception(func):
    """
    A function wrapper that swallows ShardingSpecException raised by the wrapped function.

    If ShardingSpecException occurs, the exception is logged at debug level through
    the distributed logger and the wrapper returns None. Any other exception
    propagates unchanged.

    Usage:
        # mute ShardingSpecException raised inside the function
        @ignore_sharding_exception
        def do_something():
            ...

    Args:
        func (Callable): the function to wrap.

    Returns:
        Callable: a wrapper (carrying ``func``'s metadata via ``functools.wraps``) that
        returns ``func``'s result, or None if ``func`` raised ShardingSpecException.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Keep the try body to the single call that is expected to raise.
        try:
            return func(*args, **kwargs)
        except ShardingSpecException as e:
            # The logger is only needed on the failure path, so fetch it here
            # instead of on every successful call.
            logger = get_dist_logger()
            logger.debug(e)
            return None

    return wrapper
|
|
|
|
|
|
def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tensor):
    """
    This function checks whether the ShardingSpec is valid for the physical tensor.

    This check includes 3 items:
        1. the sharding spec covers all dimensions of the physical tensor
        2. the size of each sharded dimension is divisible by (and not smaller than)
           the number of devices it is sharded over
        3. the sharding spec's entire shape must match the tensor shape

    Args:
        sharding_spec (ShardingSpec): the sharding spec to validate.
        tensor (torch.Tensor): the physical tensor the sharding spec should describe.

    Raises:
        AssertionError: if any of the checks above fails.
    """
    # Read the mesh shape once; individual axes are only indexed when a dimension
    # is actually sharded over them, so meshes with fewer than two axes work.
    mesh_shape = sharding_spec.device_mesh.shape

    # make sure all dims are covered in sharding spec
    sharding_len = len(sharding_spec.sharding_sequence)
    tensor_num_dim = tensor.dim()
    assert (
        sharding_len == tensor_num_dim
    ), f"The ShardingSpec ({sharding_spec.sharding_sequence}) is created for {sharding_len}-dimension tensor, but the given tensor is {tensor_num_dim}-dimension ({tensor.shape})."

    # make sure the sharding is valid for each dim
    for i, (dim_size, dim_spec) in enumerate(zip(tensor.shape, sharding_spec.sharding_sequence)):
        spec_str = str(dim_spec)
        if not spec_str.startswith("S"):
            continue

        # The digits after "S" name the mesh axes this dimension is sharded over
        # (e.g. "S01" -> axes 0 and 1). Each named axis is counted once.
        num_devices = 1
        for axis in set(spec_str.lstrip("S")):
            if axis.isdigit():
                num_devices *= mesh_shape[int(axis)]

        assert (
            dim_size >= num_devices and dim_size % num_devices == 0
        ), f"The dimension at index {i} has value {dim_size}, but it is sharded over {num_devices} devices."

    # make sure the entire shape matches the physical tensor shape
    assert (
        sharding_spec.entire_shape == tensor.shape
    ), f"The entire_shape of the sharding spec {sharding_spec.entire_shape} does not match the tensor shape {tensor.shape}"
|
|
|
|
|
|
def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = (), map_all: bool = False) -> Any:
    """Apply ``fn`` to the leaves of a nested dict/tuple/list structure, pytree-style.

    Containers (dict, tuple, list) are rebuilt with the same keys/ordering and their
    elements are visited recursively. A non-container leaf is passed through ``fn``
    when it is an instance of ``process_types`` or when ``map_all`` is True;
    otherwise it is returned unchanged.

    Args:
        obj (:class:`Any`): object to process
        fn (:class:`Callable`): a function applied to each selected leaf of obj
        process_types (:class: `type | tuple[type]`): leaf types that should be passed to fn
        map_all (:class: `bool`): if True, every leaf is passed to fn regardless of its type

    Returns:
        :class:`Any`: an object with the same container structure as `obj`, where the
        selected leaves have been replaced by the result of `fn`
    """

    def _visit(node: Any) -> Any:
        # Containers are checked first, so they are always traversed rather than mapped.
        if isinstance(node, dict):
            return {key: _visit(node[key]) for key in node}
        if isinstance(node, tuple):
            return tuple(_visit(item) for item in node)
        if isinstance(node, list):
            return [_visit(item) for item in node]
        # The type test is evaluated before map_all is consulted.
        if isinstance(node, process_types) or map_all:
            return fn(node)
        return node

    return _visit(obj)
|