2022-09-29 03:22:29 +00:00
|
|
|
from enum import Enum, auto
from typing import List, Tuple

import torch
from torch.fx.node import Node

from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
    CommAction,
    CommType,
    OperationData,
    OperationDataType,
)
from colossalai.tensor.comm_spec import CollectiveCommPattern, CommSpec
from colossalai.tensor.sharding_spec import ShardingSpec
|
|
|
|
|
2022-11-04 10:36:42 +00:00
|
|
|
# names exported by a star-import of this module
__all__ = [
    'BroadcastType', 'is_broadcastable', 'get_broadcast_shape', 'recover_sharding_spec_for_broadcast_shape',
    'comm_actions_for_oprands'
]
|
2022-09-29 03:22:29 +00:00
|
|
|
|
|
|
|
|
|
|
|
class BroadcastType(Enum):
    """
    Describe how one dimension of a broadcast (logical) shape relates to the
    shape of the tensor before broadcasting (physical shape).
    """
    # the dimension has the same size in the logical and the physical shape
    EQUAL = auto()
    # the dimension only exists in the logical shape
    # NOTE: the misspelled name is kept on purpose, existing code refers to it
    PADDDING = auto()
    # the dimension has size 1 in the physical shape and a different size in the logical shape
    MULTIPLE = auto()

    # correctly spelled alias of PADDDING; as it shares the value it is the same member
    # and does not show up as an extra entry when iterating over the enum
    PADDING = PADDDING
|
|
|
|
|
|
|
|
|
|
|
|
def is_broadcastable(shape1: torch.Size, shape2: torch.Size) -> bool:
    """
    Tell whether two shapes can be broadcast against each other.

    The shapes are aligned on their trailing dimension. Every aligned pair of
    sizes must either be equal or contain a 1. Leading dimensions which exist
    in only one of the shapes are not checked.
    """
    aligned_pairs = zip(reversed(shape1), reversed(shape2))
    return all(lhs == 1 or rhs == 1 or lhs == rhs for lhs, rhs in aligned_pairs)
|
|
|
|
|
|
|
|
|
|
|
|
def get_broadcast_shape(shape1: torch.Size, shape2: torch.Size) -> List[int]:
    """
    Compute the broadcast shape given two shapes.

    Args:
        shape1 (torch.Size): the first shape
        shape2 (torch.Size): the second shape

    Returns:
        List[int]: the shape obtained by broadcasting the two shapes against each other

    Raises:
        AssertionError: if the two shapes are not broadcastable
    """
    assert is_broadcastable(shape1, shape2), f'{shape1} and {shape2} are not broadcastable'
    shape1_reverse = shape1[::-1]
    shape2_reverse = shape2[::-1]
    min_common_dim = min(len(shape1), len(shape2))

    # the assertion above guarantees that the sizes of an aligned pair are equal or that one of them is 1,
    # hence the broadcast size is the size which is not 1
    # max() is not used as it gives 1 instead of 0 when a dim of size 1 meets an empty dim of size 0
    dims = [s2 if s1 == 1 else s1 for s1, s2 in zip(shape1_reverse, shape2_reverse)]

    # append the remaining dims, only the longer shape contributes here
    dims.extend(shape1_reverse[min_common_dim:])
    dims.extend(shape2_reverse[min_common_dim:])
    return dims[::-1]
|
|
|
|
|
|
|
|
|
2022-11-01 07:14:53 +00:00
|
|
|
def get_broadcast_dim_info(logical_shape, physical_shape):
    """
    Classify every dimension of the logical shape against the physical shape.

    Args:
        logical_shape: the shape after broadcasting
        physical_shape: the shape before broadcasting, it must not have more dimensions than the logical shape

    Returns:
        a dict which maps the index of a logical dimension to its BroadcastType.
        A dimension whose physical size is neither 1 nor equal to the logical size gets no entry.
    """
    num_logical = len(logical_shape)
    num_physical = len(physical_shape)

    assert num_logical >= num_physical, \
        'The number of dimensions in the logical shape is smaller than that of the physical shape, this tensor is not broadcast!'

    # the shapes are aligned on the trailing dim, so each logical dim sits `offset` positions
    # after its physical counterpart
    offset = num_logical - num_physical
    broadcast_info = {}

    # walk from the trailing dim to the leading dim
    for logical_idx in reversed(range(num_logical)):
        physical_idx = logical_idx - offset
        logical_size = logical_shape[logical_idx]

        if physical_idx < 0:
            # this dim does not exist in the physical shape
            broadcast_info[logical_idx] = BroadcastType.PADDDING
            continue

        physical_size = physical_shape[physical_idx]
        if physical_size == logical_size:
            broadcast_info[logical_idx] = BroadcastType.EQUAL
        elif physical_size == 1:
            # the sizes are known to differ here as the equality check above failed
            broadcast_info[logical_idx] = BroadcastType.MULTIPLE

    return broadcast_info
|
|
|
|
|
|
|
|
|
|
|
|
def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec, logical_shape: torch.Size,
                                              physical_shape: torch.Size) -> Tuple[ShardingSpec, List[int]]:
    """
    This function computes the sharding spec for the physical shape of a broadcast tensor.

    Args:
        logical_sharding_spec (ShardingSpec): the sharding spec for the broadcast tensor
        logical_shape (torch.Size): logical shape is the broadcast shape of a tensor
        physical_shape (torch.Size): the shape of the tensor before broadcasting

    Returns:
        Tuple[ShardingSpec, List[int]]: the sharding spec for the physical shape and the sharding dimensions
            which are removed during the conversion. When the two shapes are the same, the given sharding
            spec itself is returned together with an empty list.
    """
    # recording the sharding dimensions removed during logical shape converting to physical one
    removed_dims = []

    # if the two shapes are the same, no broadcast occurs
    # we directly return the current sharding spec
    if list(logical_shape) == list(physical_shape):
        return logical_sharding_spec, removed_dims

    # get the number of dimensions
    logical_num_dims = len(logical_shape)
    physical_num_dims = len(physical_shape)

    # get the broadcast info
    logical_dim_broadcast_info = get_broadcast_dim_info(logical_shape, physical_shape)

    # generate the sharding spec for the physical shape
    physical_dim_partition = {}
    logical_dim_partition = logical_sharding_spec.dim_partition_dict

    for shape_dim, mesh_dim in logical_dim_partition.items():
        logical_broadcast_type = logical_dim_broadcast_info[shape_dim]

        if logical_broadcast_type in (BroadcastType.PADDDING, BroadcastType.MULTIPLE):
            # this dim is padded or expanded by broadcasting, so its sharding is not carried over
            # to the physical shape and is recorded as removed instead
            removed_dims.extend(mesh_dim)
        else:
            # get the corresponding physical dim, the two shapes are aligned on the trailing dim
            physical_dim = physical_num_dims - (logical_num_dims - shape_dim)
            physical_dim_partition[physical_dim] = mesh_dim

    physical_sharding_spec = ShardingSpec(device_mesh=logical_sharding_spec.device_mesh,
                                          entire_shape=physical_shape,
                                          dim_partition_dict=physical_dim_partition)

    return physical_sharding_spec, removed_dims
|
|
|
|
|
|
|
|
|
|
|
|
def comm_actions_for_oprands(node: Node, removed_dims: List[int], op_data: OperationData,
                             sharding_spec: ShardingSpec) -> CommAction:
    """
    Build the communication action for an operand whose sharding information is
    lost when its logical shape is converted to its physical shape.

    Args:
        node (Node): the node which takes the operand as one of its arguments
        removed_dims (List[int]): the sharding dimensions removed during the shape conversion
        op_data (OperationData): the operand which needs the communication action
        sharding_spec (ShardingSpec): the sharding spec handed to the communication spec

    Returns:
        CommAction: the communication action for the given operand
    """
    # a single removed dim is passed as a bare element to avoid using flatten device mesh
    process_axis = removed_dims[0] if len(removed_dims) == 1 else removed_dims
    comm_spec = CommSpec(comm_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
                         sharding_spec=sharding_spec,
                         logical_process_axis=process_axis)

    comm_type = CommType.HOOK if op_data.type == OperationDataType.PARAM else CommType.BEFORE

    # locate the operand among the node arguments by name, the last match wins
    matched_positions = [position for position, arg in enumerate(node.args) if op_data.name == str(arg)]
    arg_index = matched_positions[-1] if matched_positions else -1
    assert arg_index >= 0, f'op_data should be an argument of node.'

    return CommAction(
        comm_spec=comm_spec,
        comm_type=comm_type,
        arg_index=arg_index,
    )
|