ColossalAI/colossalai/auto_parallel/tensor_shard/utils/reshape.py

193 lines
8.9 KiB
Python

from enum import Enum
from typing import Dict, List, Tuple
import torch
class PreviousStatus(Enum):
    """
    Remembers which side won the previous dimension-size comparison while
    ``detect_reshape_mapping`` walks the original and target shapes.

    Members:
        RESET:  no pending comparison; the last step either matched exactly or
                just offloaded an accumulated group into the mapping.
        ORIGIN: the original tensor's (accumulated) dimension size was larger.
        TGT:    the target tensor's (accumulated) dimension size was larger.
    """
    # Values are bound in declaration order: RESET -> 0, ORIGIN -> 1, TGT -> 2.
    RESET, ORIGIN, TGT = 0, 1, 2
def detect_reshape_mapping(origin_shape: torch.Size, tgt_shape: torch.Size) -> Dict[Tuple[int], Tuple[int]]:
    """
    This method is used to detect the reshape mapping between original tensor and target tensor.

    Both shapes are walked from the innermost (last) dimension outwards. Dimension sizes are
    accumulated on whichever side is currently smaller until the two running products match,
    at which point the accumulated groups of dims are recorded as one mapping entry.

    Args:
        origin_shape: shape of the tensor before the reshape (may be 0-d).
        tgt_shape: shape of the tensor after the reshape (may be 0-d). Its number of elements is
            expected to equal that of ``origin_shape``; for mismatched shapes the walk may raise
            ``IndexError``.

    Returns:
        reshape_mapping_dict: The dictionary shows how a tuple of origin dims(keys) mapping to the
            related target dims(values) during reshaping operation. Target dims of size 1 that are
            left over once every origin dim has been consumed are all collected under the
            empty-tuple key ``()``; origin dims of size 1 left over once every target dim has been
            consumed map to ``()``.

    Examples:
        import torch
        origin_shape = torch.Size([4, 4, 4])
        tgt_shape = torch.Size([2, 8, 2, 2])
        reshape_mapping_dict = detect_reshape_mapping(origin_shape, tgt_shape)
        print(reshape_mapping_dict)

    Output:
        {(2,): (3, 2), (1, 0): (1,), (0,): (0, 1)}
    """
    # reverse the shape object so the walk starts from the innermost dimension
    origin_shape = list(origin_shape)
    tgt_shape = list(tgt_shape)
    origin_shape.reverse()
    tgt_shape.reverse()

    # initialize arguments
    reshape_mapping_dict = {}
    origin_len = len(origin_shape)
    tgt_len = len(tgt_shape)
    origin_index = 0
    tgt_index = 0
    # A 0-d shape has no first dimension to read. Start it as an already-exhausted side of
    # size 1 (the same state the loop below uses once a shape runs out of dims) instead of
    # raising IndexError.
    if origin_len > 0:
        original_dimension_size = origin_shape[origin_index]
        origin_dims = [origin_len - origin_index - 1]
    else:
        original_dimension_size = 1
        origin_dims = []
    if tgt_len > 0:
        tgt_dimension_size = tgt_shape[tgt_index]
        tgt_dims = [tgt_len - tgt_index - 1]
    else:
        tgt_dimension_size = 1
        tgt_dims = []
    previous_label = PreviousStatus.RESET

    while origin_index != origin_len or tgt_index != tgt_len:
        if original_dimension_size == tgt_dimension_size:
            origin_key = tuple(origin_dims)
            if not origin_dims and origin_key in reshape_mapping_dict:
                # Once the original shape is exhausted, every leftover size-1 target dim is
                # matched against the empty key. Extend that entry instead of overwriting it,
                # otherwise all but the last such target dim would vanish from the mapping.
                reshape_mapping_dict[origin_key] += tuple(tgt_dims)
            else:
                reshape_mapping_dict[origin_key] = tuple(tgt_dims)
            # if the origin_dims has no element, it means the original tensor has been fully matched.
            # Therefore, we do not have to increase the origin_index for that case.
            if len(origin_dims) > 0:
                origin_index += 1
            # if the tgt_dims has no element, it means the target tensor has been fully matched.
            # Therefore, we do not have to increase the tgt_index for that case.
            if len(tgt_dims) > 0:
                tgt_index += 1
            # the last step of loop should always end with condition
            # so we need to manually skip the preparation for next step
            # in the last step.
            if origin_index == origin_len and tgt_index == tgt_len:
                continue
            # If origin_index equals to origin_len, we just need to set the original_dimension_size
            # to 1 to match the remaining '1's in the target tensor shape.
            if origin_index == origin_len:
                original_dimension_size = 1
                origin_dims = []
            else:
                original_dimension_size = origin_shape[origin_index]
                origin_dims = [origin_len - origin_index - 1]
            # If tgt_index equals to tgt_len, we just need to set the tgt_dimension_size
            # to 1 to match the remaining '1's in the original tensor shape.
            if tgt_index == tgt_len:
                tgt_dimension_size = 1
                tgt_dims = []
            else:
                tgt_dimension_size = tgt_shape[tgt_index]
                tgt_dims = [tgt_len - tgt_index - 1]
            previous_label = PreviousStatus.RESET
        elif original_dimension_size > tgt_dimension_size:
            tgt_index += 1
            if previous_label == PreviousStatus.TGT:
                # if the target dimension size is larger in the previous comparison, which means
                # the origin dimension size has already accumulated larger than target dimension size, so
                # we need to offload the origin dims and tgt dims into the reshape_mapping_dict.
                reshape_mapping_dict[tuple(origin_dims)] = tuple(tgt_dims)
                original_dimension_size = original_dimension_size // tgt_dimension_size
                # the last accumulated origin dim is split: its remainder carries on
                origin_dims = [origin_len - origin_index - 1]
                tgt_dimension_size = tgt_shape[tgt_index]
                # the new target dim plus the previous one, which shares the split origin dim
                tgt_dims = [tgt_len - tgt_index - 1, tgt_len - tgt_index]
                # reset the previous_label after offloading the origin dims and tgt dims
                previous_label = PreviousStatus.RESET
            else:
                # accumulate the tgt_dimension_size until tgt_dimension_size larger than original_dimension_size
                tgt_dimension_size *= tgt_shape[tgt_index]
                tgt_dims.append(tgt_len - tgt_index - 1)
                previous_label = PreviousStatus.ORIGIN
        else:
            origin_index += 1
            if previous_label == PreviousStatus.ORIGIN:
                # if the origin element is larger in the previous comparison, which means
                # the target element has already accumulated larger than origin element, so
                # we need to offload the origin dims and tgt dims into the reshape_mapping_dict.
                reshape_mapping_dict[tuple(origin_dims)] = tuple(tgt_dims)
                tgt_dimension_size = tgt_dimension_size // original_dimension_size
                # the last accumulated target dim is split: its remainder carries on
                tgt_dims = [tgt_len - tgt_index - 1]
                original_dimension_size = origin_shape[origin_index]
                # the new origin dim plus the previous one, which shares the split target dim
                origin_dims = [origin_len - origin_index - 1, origin_len - origin_index]
                # reset the previous_label after offloading the origin dims and tgt dims
                previous_label = PreviousStatus.RESET
            else:
                # accumulate the original_dimension_size until original_dimension_size larger than tgt_dimension_size
                original_dimension_size *= origin_shape[origin_index]
                origin_dims.append(origin_len - origin_index - 1)
                previous_label = PreviousStatus.TGT
    return reshape_mapping_dict
def check_keep_sharding_status(input_dim_partition_dict: Dict[int, List[int]],
                               reshape_mapping_dict: Dict[Tuple[int], Tuple[int]]) -> bool:
    """
    This method is used to check whether the reshape operation could implement without converting
    the input to fully replicated status.

    Rule:
        For a sharded dimension of input tensor, if it is not the minimum element of the input tuple,
        the function will return false.
        To illustrate this issue, there are two cases to analyze:
        1. no sharded dims in the input tuple: we could do the reshape operation safely just as the normal
           operation without distributed tensor.
        2. sharded dims in the input tuple: the sharded dim must be the minimum element, then during shape
           consistency process, torch.cat will be implemented on the sharded dim, and everything after the sharded
           dim get recovered.

    Args:
        input_dim_partition_dict: keys are the sharded dims of the input tensor; values are not read here.
        reshape_mapping_dict: mapping from tuples of input dims to tuples of output dims, as produced
            by ``detect_reshape_mapping``. Empty input-dim tuples are ignored.

    Returns:
        True if every sharded input dim is the minimum dim of the input group it belongs to.

    Examples:
        # the second dimension of the input has been sharded.
        input_dim_partition_dict = {1: [1]}
        origin_shape = torch.Size([8, 4, 2])
        tgt_shape = torch.Size([2, 4, 8])
        reshape_mapping_dict = detect_reshape_mapping(origin_shape, tgt_shape)
        # {(2, 1): (2,), (0,): (1, 0)}
        # the sharded dim of input is 1, which is the minimum element of the tuple (2, 1),
        # so we do not have to convert the input to fully replicated status.
        print(check_keep_sharding_status(input_dim_partition_dict, reshape_mapping_dict))

    Output:
        True
    """
    sharded_dims = set(input_dim_partition_dict)
    for input_dims in reshape_mapping_dict:
        # if input_dims has no element, we could just skip this iteration.
        if not input_dims:
            continue
        min_element = min(input_dims)
        for dim in input_dims:
            # Compare by value; `is not` on ints tests object identity, not equality.
            if dim in sharded_dims and dim != min_element:
                return False
    return True
def infer_output_dim_partition_dict(input_dim_partition_dict: Dict[int, List[int]],
                                    reshape_mapping_dict: Dict[Tuple[int], Tuple[int]]) -> Dict[int, List[int]]:
    """
    This method is used to infer the output dim partition dict for a reshape operation,
    given the input dim partition dict and reshape mapping dict.

    For every group of input dims that contains a sharded dim, that dim's partition is placed on
    the minimum dim of the corresponding output group.

    Args:
        input_dim_partition_dict: maps each sharded input dim to its partition list
            (presumably the device-mesh axes it is sharded along -- TODO confirm with callers).
        reshape_mapping_dict: mapping from tuples of input dims to tuples of output dims, as
            produced by ``detect_reshape_mapping``.

    Returns:
        output_dim_partition_dict: maps output dims to the same partition list objects taken
            from ``input_dim_partition_dict`` (not copies).

    Raises:
        AssertionError: if ``check_keep_sharding_status`` reports that the sharding cannot be kept.
    """
    assert check_keep_sharding_status(input_dim_partition_dict, reshape_mapping_dict), \
        'we only infer output dim partition dict for the reshape operation could keep sharding spec.'
    sharded_dims = set(input_dim_partition_dict)
    output_dim_partition_dict = {}
    for input_dims, output_dims in reshape_mapping_dict.items():
        for dim in input_dims:
            if dim in sharded_dims:
                output_dim_partition_dict[min(output_dims)] = input_dim_partition_dict[dim]
                # we could break because input dims cannot contain two sharded dims, otherwise
                # the keep sharding status check will fail.
                break
    return output_dim_partition_dict