[autoparallel] add experimental view handler (#2011)

* [autoparallel] add experimental view handler

* polish

* polish

* polish code

* rename variables
pull/2018/head
YuliangLiu0306 2022-11-24 11:34:41 +08:00 committed by GitHub
parent d655eea515
commit 1438993113
8 changed files with 459 additions and 2 deletions

View File

@ -0,0 +1,4 @@
from .view_generator import ViewGenerator
from .view_handler import ViewHandler
__all__ = ['ViewGenerator', 'ViewHandler']

View File

@ -0,0 +1,133 @@
import copy
from typing import List
from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
MemoryCost,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.utils import (
check_keep_sharding_status,
detect_reshape_mapping,
infer_output_dim_partition_dict,
)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from colossalai.tensor.sharding_spec import ShardingSpec
__all__ = ['ViewGenerator']
class ViewGenerator(FollowingStrategyGenerator):
"""
    ViewGenerator which deals with the sharding strategies of view ops.
"""
def validate(self) -> bool:
return super().validate()
def update_compute_cost(self, strategy: ShardingStrategy):
compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20)
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
'''
Compute the memory cost per device with this specific strategy.
'''
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'output': self._compute_size_in_bytes(strategy, "output")
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
backward_size_mapping.pop("output")
# compute fwd cost incurred
# fwd_cost = input + output
fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)])
fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)
# compute bwd cost incurred
# bwd_cost = input_grad
bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)])
bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
parameter=fwd_parameter_cost + bwd_parameter_cost)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
dim_partition_dict_mapping = {}
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
origin_shape = self.op_data['input'].data.shape
tgt_shape = self.op_data['tgt_shape'].data
reshape_mapping_dict = detect_reshape_mapping(origin_shape, tgt_shape)
dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
keep_sharding_status = check_keep_sharding_status(dim_partition_dict_for_input, reshape_mapping_dict)
if keep_sharding_status:
dim_partition_dict_for_output = infer_output_dim_partition_dict(dim_partition_dict_for_input,
reshape_mapping_dict)
else:
dim_partition_dict_for_output = {}
dim_partition_dict_mapping = {
"input": dim_partition_dict_for_input,
"output": dim_partition_dict_for_output,
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
            # append the index to the name to pass the duplicate check
            # we keep identical strategies under different names for node merging; this does not enlarge the search space,
            # because the solver will merge this node into other nodes and will not create a new variable for it.
if keep_sharding_status:
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
else:
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> FULLY REPLICATED_{index}'
            # add a comm action to convert the input to fully replicated
total_mesh_dim_list = []
for mesh_dim_list in dim_partition_dict_for_input.values():
total_mesh_dim_list.extend(mesh_dim_list)
            # if only one mesh dimension is used for sharding, use the value itself instead of a list as logical_process_axis.
if len(total_mesh_dim_list) == 1:
total_mesh_dim_list = total_mesh_dim_list[0]
input_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
logical_process_axis=total_mesh_dim_list,
comm_type=CommType.BEFORE,
arg_index=0)
input_comm_action.comm_spec.gather_dim = total_mesh_dim_list
elif len(total_mesh_dim_list) >= 2:
source_spec = sharding_spec_mapping["input"]
target_spec = ShardingSpec(device_mesh=self.device_mesh,
entire_shape=source_spec.entire_shape,
dim_partition_dict={})
comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)
else:
input_comm_action = None
if input_comm_action is not None:
communication_action_mapping["input"] = input_comm_action
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy_list.append(strategy)
return strategy_list
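
The branch on total_mesh_dim_list above decides how the input gets converted when the view cannot keep its sharding status. As a toy sketch (plain Python, no colossalai objects; pick_comm_branch is a hypothetical helper written only for illustration), the same decision logic looks like this:

def pick_comm_branch(dim_partition_dict_for_input):
    # collect every mesh dimension the input is currently sharded over
    total_mesh_dim_list = []
    for mesh_dim_list in dim_partition_dict_for_input.values():
        total_mesh_dim_list.extend(mesh_dim_list)
    if len(total_mesh_dim_list) == 1:
        # one sharded mesh dim: all-gather along that axis (GATHER_FWD_SPLIT_BWD)
        return 'gather', total_mesh_dim_list[0]
    elif len(total_mesh_dim_list) >= 2:
        # several sharded mesh dims: fall back to a full shape-consistency conversion
        return 'shape_consistency', total_mesh_dim_list
    # input is already fully replicated: no communication action is needed
    return 'none', None

print(pick_comm_branch({0: [0]}))          # ('gather', 0)
print(pick_comm_branch({0: [0], 1: [1]}))  # ('shape_consistency', [0, 1])
print(pick_comm_branch({}))                # ('none', None)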

View File

@ -0,0 +1,51 @@
from typing import Dict, List
import torch
from ...sharding_strategy import OperationData, OperationDataType
from ..node_handler import NodeHandler
from ..registry import operator_registry
from ..strategy import StrategyGenerator
from .view_generator import ViewGenerator
__all__ = ['ViewHandler']
@operator_registry.register(torch.Tensor.view)
class ViewHandler(NodeHandler):
"""
    A ViewHandler which deals with the sharding strategies of reshape ops, such as torch.Tensor.view.
"""
def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(ViewGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
        # the strategies will be transformed back to their original shapes in self.post_process
# check if the input operand is a parameter
if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
data_type = OperationDataType.PARAM
else:
data_type = OperationDataType.ARG
input_data = self.node.args[0]._meta_data
physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)
target_shape = self.node._meta_data.shape
physical_shape_operand = OperationData(name='tgt_shape', type=OperationDataType.ARG, data=target_shape)
output_data = self.node._meta_data
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
mapping = {
"input": physical_input_operand,
"tgt_shape": physical_shape_operand,
"output": physical_output_operand
}
return mapping
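
For a concrete picture of the three operands built in get_operation_data_mapping, here is a minimal plain-torch sketch (no colossalai objects; the shapes are borrowed from the test at the end of this PR, and the variable names are made up for the illustration):

import torch

input_meta = torch.rand(8, 16, 64, 64, device='meta')    # "input": meta data of the predecessor (conv2d) node
output_meta = input_meta.view(32, 4, 32, 32, 4)           # "output": meta data of the view node itself
target_shape = output_meta.shape                          # "tgt_shape": kept as a torch.Size, not a tensor

print(input_meta.is_meta)   # True
print(target_shape)         # torch.Size([32, 4, 32, 32, 4])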

View File

@ -51,6 +51,8 @@ class OperationData:
"""
if isinstance(data, torch.Tensor):
return data.shape
elif isinstance(data, torch.Size):
return None
elif isinstance(data, (tuple, list)):
data_type = type(data)
return data_type([_infer_logical_shape(d) for d in data])

View File

@ -82,7 +82,6 @@ class StrategiesConstructor:
for node in self.nodes:
strategies_vector = StrategiesVector(node)
print(node)
if _check_no_strategy_for_node(node):
no_strategy_node.append(node)
pass

View File

@ -7,6 +7,7 @@ from .broadcast import (
)
from .factory import generate_resharding_costs, generate_sharding_spec
from .misc import check_sharding_spec_validity, ignore_sharding_exception, pytree_map
from .reshape import check_keep_sharding_status, detect_reshape_mapping, infer_output_dim_partition_dict
from .sharding import (
enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding,
@ -19,5 +20,6 @@ __all__ = [
'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape',
'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'check_sharding_spec_validity',
'transpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
'enumerate_all_possible_2d_sharding', 'generate_sharding_size', 'comm_actions_for_oprands', 'pytree_map'
'enumerate_all_possible_2d_sharding', 'generate_sharding_size', 'comm_actions_for_oprands', 'pytree_map',
'detect_reshape_mapping', 'check_keep_sharding_status', 'infer_output_dim_partition_dict'
]

View File

@ -0,0 +1,168 @@
from enum import Enum
from typing import Dict, List, Tuple
import torch
class PreviousStatus(Enum):
"""
    This class shows the status of the previous comparison.
    """
    RESET = 0
    # ORIGIN means the dimension size of the original tensor was larger in the previous comparison.
    ORIGIN = 1
    # TGT means the dimension size of the target tensor was larger in the previous comparison.
TGT = 2
def detect_reshape_mapping(origin_shape: torch.Size, tgt_shape: torch.Size) -> Dict[Tuple[int], Tuple[int]]:
"""
    This method is used to detect the reshape mapping between the original tensor and the target tensor.
    Returns:
        reshape_mapping_dict: A dictionary that maps each tuple of origin dims (keys) to the related
            tuple of target dims (values) during the reshape operation.
Examples:
import torch
origin_shape = torch.Size([4, 4, 4])
tgt_shape = torch.Size([2, 8, 2, 2])
reshape_mapping_dict = detect_reshape_mapping(origin_shape, tgt_shape)
print(reshape_mapping_dict)
Output:
{(2,): (3, 2), (1, 0): (1,), (0,): (0, 1)}
"""
# reverse the shape object
origin_shape = list(origin_shape)
tgt_shape = list(tgt_shape)
origin_shape.reverse()
tgt_shape.reverse()
# initialize arguments
reshape_mapping_dict = {}
origin_len = len(origin_shape)
tgt_len = len(tgt_shape)
origin_index = 0
tgt_index = 0
original_dimension_size = origin_shape[origin_index]
tgt_dimension_size = tgt_shape[tgt_index]
tgt_dims = [tgt_len - tgt_index - 1]
origin_dims = [origin_len - origin_index - 1]
previous_label = PreviousStatus.RESET
while origin_index != len(origin_shape) or tgt_index != len(tgt_shape):
if original_dimension_size == tgt_dimension_size:
reshape_mapping_dict[tuple(origin_dims)] = tuple(tgt_dims)
origin_index += 1
tgt_index += 1
            # the last step of the loop always ends in this branch,
            # so we manually skip the preparation for the next step
            # on the final iteration.
if origin_index == len(origin_shape):
continue
original_dimension_size = origin_shape[origin_index]
tgt_dimension_size = tgt_shape[tgt_index]
origin_dims = [origin_len - origin_index - 1]
tgt_dims = [tgt_len - tgt_index - 1]
previous_label = PreviousStatus.RESET
elif original_dimension_size > tgt_dimension_size:
tgt_index += 1
if previous_label == PreviousStatus.TGT:
                # the target dimension size was larger in the previous comparison, which means
                # the origin dimension size has now accumulated past the target dimension size, so
                # we need to offload the origin dims and tgt dims into the reshape_mapping_dict.
reshape_mapping_dict[tuple(origin_dims)] = tuple(tgt_dims)
original_dimension_size = original_dimension_size // tgt_dimension_size
origin_dims = [origin_len - origin_index - 1]
tgt_dimension_size = tgt_shape[tgt_index]
tgt_dims = [tgt_len - tgt_index - 1, tgt_len - tgt_index]
# reset the previous_label after offloading the origin dims and tgt dims
previous_label = PreviousStatus.RESET
else:
                # accumulate tgt_dimension_size until it is larger than original_dimension_size
tgt_dimension_size *= tgt_shape[tgt_index]
tgt_dims.append(tgt_len - tgt_index - 1)
previous_label = PreviousStatus.ORIGIN
else:
origin_index += 1
if previous_label == PreviousStatus.ORIGIN:
                # the origin dimension size was larger in the previous comparison, which means
                # the target dimension size has now accumulated past the origin dimension size, so
                # we need to offload the origin dims and tgt dims into the reshape_mapping_dict.
reshape_mapping_dict[tuple(origin_dims)] = tuple(tgt_dims)
tgt_dimension_size = tgt_dimension_size // original_dimension_size
tgt_dims = [tgt_len - tgt_index - 1]
original_dimension_size = origin_shape[origin_index]
origin_dims = [origin_len - origin_index - 1, origin_len - origin_index]
# reset the previous_label after offloading the origin dims and tgt dims
previous_label = PreviousStatus.RESET
else:
                # accumulate original_dimension_size until it is larger than tgt_dimension_size
original_dimension_size *= origin_shape[origin_index]
origin_dims.append(origin_len - origin_index - 1)
previous_label = PreviousStatus.TGT
return reshape_mapping_dict
def check_keep_sharding_status(input_dim_partition_dict: Dict[int, List[int]],
reshape_mapping_dict: Dict[Tuple[int], Tuple[int]]) -> bool:
"""
    This method is used to check whether the reshape operation can be implemented without converting
    the input to fully replicated status.
    Rule:
        For a sharded dimension of the input tensor, if it is not the minimum element of its input tuple,
        the function returns False.
    To illustrate this rule, there are two cases to analyse:
        1. no sharded dims in the input tuple: the reshape can be done safely, just like a normal
           operation on a non-distributed tensor.
        2. sharded dims in the input tuple: the sharded dim must be the minimum element; then, during the
           shape consistency process, torch.cat is applied along the sharded dim, and everything after the
           sharded dim is recovered.
Examples:
# the second dimension of the input has been sharded.
input_dim_partition_dict = {1: [1]}
origin_shape = torch.Size([8, 4, 2])
tgt_shape = torch.Size([2, 4, 8])
reshape_mapping_dict = detect_reshape_mapping(origin_shape, tgt_shape)
# {(2, 1): (2,), (0,): (1, 0)}
# the sharded dim of input is 1, which is the minimum element of the tuple (2, 1),
# so we do not have to convert the input to fully replicated status.
print(check_keep_sharding_status(input_dim_partition_dict, reshape_mapping_dict))
Output:
True
"""
sharded_dims = list(input_dim_partition_dict.keys())
for input_dims in reshape_mapping_dict.keys():
min_element = min(input_dims)
for dim in input_dims:
            if dim in sharded_dims and dim != min_element:
return False
return True
def infer_output_dim_partition_dict(input_dim_partition_dict: Dict[int, List[int]],
                                    reshape_mapping_dict: Dict[Tuple[int], Tuple[int]]) -> Dict[int, List[int]]:
"""
This method is used to infer the output dim partition dict for a reshape operation,
given the input dim partition dict and reshape mapping dict.
"""
    assert check_keep_sharding_status(input_dim_partition_dict, reshape_mapping_dict), \
        'we only infer the output dim partition dict for reshape operations that can keep the sharding spec.'
sharded_dims = list(input_dim_partition_dict.keys())
output_dim_partition_dict = {}
for input_dims, output_dims in reshape_mapping_dict.items():
for dim in input_dims:
if dim in sharded_dims:
output_dim_partition_dict[min(output_dims)] = input_dim_partition_dict[dim]
                # we can break here because an input group cannot contain two sharded dims; otherwise
                # the keep-sharding-status check above would have failed.
break
return output_dim_partition_dict
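
A hedged usage sketch chaining the three utilities added in this file, assuming the colossalai package from this PR is importable; the expected mapping and the True branch come from the docstrings above, and the final dict is worked out by hand from the logic of infer_output_dim_partition_dict:

import torch

from colossalai.auto_parallel.tensor_shard.utils import (
    check_keep_sharding_status,
    detect_reshape_mapping,
    infer_output_dim_partition_dict,
)

origin_shape = torch.Size([8, 4, 2])
tgt_shape = torch.Size([2, 4, 8])
reshape_mapping_dict = detect_reshape_mapping(origin_shape, tgt_shape)
print(reshape_mapping_dict)    # {(2, 1): (2,), (0,): (1, 0)} per the docstring example

# the second dim of the input is sharded along mesh dim 1
input_dim_partition_dict = {1: [1]}
if check_keep_sharding_status(input_dim_partition_dict, reshape_mapping_dict):
    # dim 1 is the minimum of its group (2, 1), so the sharding spec can be kept;
    # the sharded dim follows the minimum output dim of its group, giving {2: [1]} here
    print(infer_output_dim_partition_dict(input_dim_partition_dict, reshape_mapping_dict))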

View File

@ -0,0 +1,98 @@
import torch
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.experimental import ViewHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing.pytest_wrapper import run_on_environment_flag
class ViewModel(nn.Module):
def __init__(self):
super().__init__()
def forward(self, input, other):
conv_node = nn.functional.conv2d(input, other)
reshape_node = conv_node.view(32, 4, 32, 32, 4)
return reshape_node
def test_view_handler():
model = ViewModel()
tracer = ColoTracer()
# graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# %other : torch.Tensor [#users=1] = placeholder[target=other]
# %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {})
    # %view : [#users=1] = call_method[target=view](args = (%conv2d, 32, 4, 32, 32, 4), kwargs = {})
# return view
graph = tracer.trace(model,
meta_args={
"input": torch.rand(8, 8, 66, 66).to('meta'),
"other": torch.rand(16, 8, 3, 3).to('meta'),
})
gm = ColoGraphModule(model, graph)
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
conv_mod_node = list(graph.nodes)[2]
view_node = list(graph.nodes)[3]
view_strategies_vector = StrategiesVector(view_node)
conv_strategies_vector = StrategiesVector(conv_mod_node)
# build handler
conv_handler = ConvFunctionHandler(node=conv_mod_node,
device_mesh=device_mesh,
strategies_vector=conv_strategies_vector)
conv_handler.register_strategy(compute_resharding_cost=False)
setattr(conv_mod_node, 'strategies_vector', conv_strategies_vector)
view_handler = ViewHandler(node=view_node, device_mesh=device_mesh, strategies_vector=view_strategies_vector)
view_handler.register_strategy(compute_resharding_cost=False)
# check operation data mapping
mapping = view_handler.get_operation_data_mapping()
for name, op_data in mapping.items():
op_data: OperationData
# make sure they have valid values
assert op_data.data is not None
assert mapping['input'].name == "conv2d"
assert mapping['input'].data.is_meta
assert mapping['input'].data.shape == torch.Size([8, 16, 64, 64])
assert mapping['input'].type == OperationDataType.ARG
assert mapping['input'].logical_shape == torch.Size([8, 16, 64, 64])
assert mapping['output'].name == "view"
assert mapping['output'].data.is_meta
assert mapping['output'].data.shape == torch.Size([32, 4, 32, 32, 4])
assert mapping['output'].type == OperationDataType.OUTPUT
    # the view handler is a following-strategy handler, so its number of strategies equals that of the predecessor node.
assert len(view_strategies_vector) == len(conv_strategies_vector)
strategy_name_list = [strategy.name for strategy in view_strategies_vector]
assert '[S0, S1, R, R] -> FULLY REPLICATED_0' in strategy_name_list
assert '[S1, S0, R, R] -> FULLY REPLICATED_1' in strategy_name_list
assert '[S0, R, R, R] -> [S0, R, R, R, R]_2' in strategy_name_list
assert '[S1, R, R, R] -> [S1, R, R, R, R]_3' in strategy_name_list
assert '[S0, R, R, R] -> [S0, R, R, R, R]_4' in strategy_name_list
assert '[S1, R, R, R] -> [S1, R, R, R, R]_5' in strategy_name_list
assert '[R, S1, R, R] -> FULLY REPLICATED_6' in strategy_name_list
assert '[R, S0, R, R] -> FULLY REPLICATED_7' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R, R]_8' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R, R]_9' in strategy_name_list
assert '[R, S0, R, R] -> FULLY REPLICATED_10' in strategy_name_list
assert '[R, S1, R, R] -> FULLY REPLICATED_11' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R, R]_12' in strategy_name_list
assert '[S01, R, R, R] -> [S01, R, R, R, R]_13' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R, R]_14' in strategy_name_list
assert '[R, S01, R, R] -> FULLY REPLICATED_15' in strategy_name_list
if __name__ == '__main__':
test_view_handler()