[autoparallel] refactor handlers which reshape input tensors (#2615)

* [autoparallel] refactor handlers which reshape input tensors

* polish
pull/2664/head
YuliangLiu0306 2023-02-08 15:02:49 +08:00 committed by GitHub
parent 28398f1c70
commit 37df666f38
15 changed files with 307 additions and 365 deletions
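In short, this change moves the reshape-style handlers (view, permute, transpose, split) out of the experimental sub-package into the top-level node_handler package and renames the generic ReshapeHandler/ReshapeGenerator pair to DefaultReshapeHandler/DefaultReshapeGenerator, with ReshapeGenerator kept as the shared base class. A sketch of the import paths before and after, as reflected in the test updates further down (anything outside this diff is assumed unchanged):

# before this commit
from colossalai.auto_parallel.tensor_shard.node_handler.reshape_handler import ReshapeHandler
from colossalai.auto_parallel.tensor_shard.node_handler.experimental import PermuteHandler, SplitHandler, TransposeHandler, ViewHandler

# after this commit
from colossalai.auto_parallel.tensor_shard.node_handler import (
    DefaultReshapeHandler,    # formerly ReshapeHandler
    PermuteHandler,
    SplitHandler,
    TransposeHandler,
    ViewHandler,
)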

View File

@@ -3,8 +3,8 @@ from .batch_norm_handler import BatchNormModuleHandler
from .binary_elementwise_handler import BinaryElementwiseHandler
from .bmm_handler import AddBMMFunctionHandler, BMMFunctionHandler
from .conv_handler import ConvFunctionHandler, ConvModuleHandler
from .default_reshape_handler import DefaultReshapeHandler
from .embedding_handler import EmbeddingFunctionHandler, EmbeddingModuleHandler
from .experimental import PermuteHandler, ViewHandler
from .getattr_handler import GetattrHandler
from .getitem_handler import GetItemHandler
from .layer_norm_handler import LayerNormModuleHandler
@@ -13,20 +13,24 @@ from .matmul_handler import MatMulHandler
from .normal_pooling_handler import NormPoolingHandler
from .option import ShardOption
from .output_handler import OutputHandler
from .permute_handler import PermuteHandler
from .placeholder_handler import PlaceholderHandler
from .registry import operator_registry
from .reshape_handler import ReshapeHandler
from .softmax_handler import SoftmaxHandler
from .split_handler import SplitHandler
from .sum_handler import SumHandler
from .tensor_constructor_handler import TensorConstructorHandler
from .transpose_handler import TransposeHandler
from .unary_elementwise_handler import UnaryElementwiseHandler
from .view_handler import ViewHandler
from .where_handler import WhereHandler
__all__ = [
'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'AddBMMFunctionHandler',
'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler',
'UnaryElementwiseHandler', 'ReshapeHandler', 'PlaceholderHandler', 'OutputHandler', 'WhereHandler',
'UnaryElementwiseHandler', 'DefaultReshapeHandler', 'PlaceholderHandler', 'OutputHandler', 'WhereHandler',
'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler',
'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler', 'TensorConstructorHandler',
'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler', 'ShardOption'
'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler', 'ShardOption',
'TransposeHandler', 'SplitHandler'
]

View File

@@ -5,23 +5,23 @@ import torch
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import MetaInfoNodeHandler, NodeHandler
from .registry import operator_registry
from .strategy import ReshapeGenerator, StrategyGenerator
from .strategy import DefaultReshapeGenerator, StrategyGenerator
__all__ = ['ReshapeHandler']
__all__ = ['DefaultReshapeHandler']
@operator_registry.register(torch.flatten)
@operator_registry.register(torch.Tensor.unsqueeze)
@operator_registry.register(torch.nn.AdaptiveAvgPool2d)
class ReshapeHandler(MetaInfoNodeHandler):
class DefaultReshapeHandler(MetaInfoNodeHandler):
"""
A ReshapeHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape.
A DefaultReshapeHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape.
"""
def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(ReshapeGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
generators.append(DefaultReshapeGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
return generators
def infer_logical_shape(self, data):
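For orientation, a handler like DefaultReshapeHandler above is mapped to its torch ops by the operator_registry decorators and then built once per traced FX node. A minimal usage sketch; reshape_node, device_mesh and reshape_strategies_vector are placeholders assumed to come from the usual ColoTracer/DeviceMesh setup shown in the updated tests near the end of this diff:

from colossalai.auto_parallel.tensor_shard.node_handler import DefaultReshapeHandler

# reshape_node, device_mesh and reshape_strategies_vector are assumed to be
# produced by the surrounding tracing/solver setup (see the tests below)
handler = DefaultReshapeHandler(node=reshape_node,
                                device_mesh=device_mesh,
                                strategies_vector=reshape_strategies_vector)
# populates the strategies vector; for a following-strategy handler the tests
# assert one strategy per predecessor strategy
handler.register_strategy(compute_resharding_cost=False)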

View File

@@ -1,10 +0,0 @@
from .permute_handler import PermuteHandler
from .reshape_generator import PermuteGenerator, SplitGenerator, TransposeGenerator, ViewGenerator
from .split_handler import SplitHandler
from .transpose_handler import TransposeHandler
from .view_handler import ViewHandler
__all__ = [
'ViewGenerator', 'ViewHandler', 'PermuteGenerator', 'PermuteHandler', 'TransposeGenerator', 'TransposeGenerator',
'SplitHandler', 'SplitGenerator'
]

View File

@@ -1,299 +0,0 @@
import copy
from typing import List
from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
MemoryCost,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.utils import (
check_keep_sharding_status,
detect_reshape_mapping,
infer_output_dim_partition_dict,
)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from colossalai.tensor.sharding_spec import ShardingSpec
__all__ = ['ReshapeGenerator', 'ViewGenerator', 'PermuteGenerator', 'TransposeGenerator', 'SplitGenerator']
class ReshapeGenerator(FollowingStrategyGenerator):
"""
ReshapeGenerator is the base class for all the reshape operation.
"""
def validate(self) -> bool:
return super().validate()
def update_compute_cost(self, strategy: ShardingStrategy):
compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20)
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
'''
Compute the memory cost per device with this specific strategy.
'''
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'output': self._compute_size_in_bytes(strategy, "output")
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
backward_size_mapping.pop("output")
# compute fwd cost incurred
# fwd_cost = input + output
fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)])
fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)
# compute bwd cost incurred
# bwd_cost = input_grad
bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)])
bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
parameter=fwd_parameter_cost + bwd_parameter_cost)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
def collate_strategies(self) -> List[ShardingStrategy]:
return super().collate_strategies()
class ViewGenerator(ReshapeGenerator):
"""
ViewGenerator deals with the sharding strategies of view op.
"""
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
dim_partition_dict_mapping = {}
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
origin_shape = self.op_data['input'].data.shape
tgt_shape = self.op_data['tgt_shape'].data
reshape_mapping_dict = detect_reshape_mapping(origin_shape, tgt_shape)
dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
keep_sharding_status = check_keep_sharding_status(dim_partition_dict_for_input, reshape_mapping_dict)
if keep_sharding_status:
dim_partition_dict_for_output = infer_output_dim_partition_dict(dim_partition_dict_for_input,
reshape_mapping_dict)
else:
dim_partition_dict_for_output = {}
dim_partition_dict_mapping = {
"input": dim_partition_dict_for_input,
"output": dim_partition_dict_for_output,
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# add index into name to pass the duplicated check
# we keep same strategies with different name for node merging, and it will not increase the searching space,
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
if keep_sharding_status:
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
else:
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> FULLY REPLICATED_{index}'
# add comm action for converting input to fully replicated
total_mesh_dim_list = []
for mesh_dim_list in dim_partition_dict_for_input.values():
total_mesh_dim_list.extend(mesh_dim_list)
# if there is only one sharding dimension, we should use the value instead of list as logical_process_axis.
if len(total_mesh_dim_list) == 1:
total_mesh_dim_list = total_mesh_dim_list[0]
# the total mesh dim list only has one element, so the shard dim has only one element as well.
shard_dim = list(dim_partition_dict_for_input.keys())[0]
input_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
logical_process_axis=total_mesh_dim_list,
comm_type=CommType.BEFORE,
arg_index=0)
# it will gather the input through gather_dim during forward phase.
input_comm_action.comm_spec.gather_dim = shard_dim
# it will split the input activation grad through shard_dim during backward phase.
input_comm_action.comm_spec.shard_dim = shard_dim
elif len(total_mesh_dim_list) >= 2:
source_spec = sharding_spec_mapping["input"]
target_spec = ShardingSpec(device_mesh=self.device_mesh,
entire_shape=source_spec.entire_shape,
dim_partition_dict={})
comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)
else:
input_comm_action = None
if input_comm_action is not None:
communication_action_mapping["input"] = input_comm_action
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy_list.append(strategy)
return strategy_list
class PermuteGenerator(ReshapeGenerator):
"""
PermuteGenerator deals with the sharding strategies of permute op.
"""
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
dim_partition_dict_mapping = {}
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
permute_dims = self.op_data['permute_dims'].data
dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
dim_partition_dict_for_output = {}
for dim_index, permute_dim in enumerate(permute_dims):
if permute_dim in dim_partition_dict_for_input:
dim_partition_dict_for_output[dim_index] = dim_partition_dict_for_input[permute_dim]
dim_partition_dict_mapping = {
"input": dim_partition_dict_for_input,
"output": dim_partition_dict_for_output,
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# add index into name to pass the duplicated check
# we keep same strategies with different name for node merging, and it will not increase the searching space,
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy_list.append(strategy)
return strategy_list
class TransposeGenerator(ReshapeGenerator):
"""
TransposeGenerator deals with the sharding strategies of permute op.
"""
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
dim_partition_dict_mapping = {}
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
dim_partition_dict_for_output = {}
transpose_dims = self.op_data['transpose_dims'].data
dim_0 = transpose_dims[0]
dim_1 = transpose_dims[1]
for dim, sharded_dims in dim_partition_dict_for_input.items():
if dim == dim_0:
dim_partition_dict_for_output[dim_1] = dim_partition_dict_for_input[dim_0]
elif dim == dim_1:
dim_partition_dict_for_output[dim_0] = dim_partition_dict_for_input[dim_1]
else:
dim_partition_dict_for_output[dim] = sharded_dims
dim_partition_dict_mapping = {
"input": dim_partition_dict_for_input,
"output": dim_partition_dict_for_output,
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# add index into name to pass the duplicated check
# we keep same strategies with different name for node merging, and it will not increase the searching space,
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy_list.append(strategy)
return strategy_list
class SplitGenerator(ReshapeGenerator):
"""
SplitGenerator deals with the sharding strategies of split op.
"""
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
recover_dims = None
dim_partition_dict_mapping = {}
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict)
split_size, split_dim = self.op_data['split_info'].data
if split_dim in dim_partition_dict_for_input:
recover_dims = dim_partition_dict_for_input.pop(split_dim)
dim_partition_dict_for_output = [
copy.deepcopy(dim_partition_dict_for_input) for _ in range(len(self.op_data["output"].data))
]
assert len(dim_partition_dict_for_output) >= 2
dim_partition_dict_mapping = {
"input": dim_partition_dict_for_input,
"output": dim_partition_dict_for_output,
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# add index into name to pass the duplicated check
# we keep same strategies with different name for node merging, and it will not increase the searching space,
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
name = f'{sharding_spec_mapping["input"].sharding_sequence}_{index}'
# add comm action if the input need to be recovered to replica in the split dimension.
if recover_dims:
# if there is only one sharding dimension, we should use the value instead of list as logical_process_axis.
if len(recover_dims) == 1:
recover_dims = recover_dims[0]
input_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
logical_process_axis=recover_dims,
comm_type=CommType.BEFORE,
arg_index=0)
# it will gather the input through gather_dim during forward phase.
input_comm_action.comm_spec.gather_dim = split_dim
# it will split the input activation grad through split_dim during backward phase.
input_comm_action.comm_spec.shard_dim = split_dim
elif len(recover_dims) >= 2:
# original sharding spec
source_spec = input_sharding_spec
# target sharding spec
target_spec = sharding_spec_mapping["input"]
comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)
else:
input_comm_action = None
if input_comm_action is not None:
communication_action_mapping["input"] = input_comm_action
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy_list.append(strategy)
return strategy_list

View File

@@ -2,11 +2,10 @@ from typing import Dict, List
import torch
from ...sharding_strategy import OperationData, OperationDataType
from ..node_handler import NodeHandler
from ..registry import operator_registry
from ..strategy import StrategyGenerator
from .reshape_generator import PermuteGenerator
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import PermuteGenerator, StrategyGenerator
__all__ = ['PermuteHandler']

View File

@@ -2,11 +2,10 @@ from typing import Dict, List
import torch
from ...sharding_strategy import OperationData, OperationDataType
from ..node_handler import NodeHandler
from ..registry import operator_registry
from ..strategy import StrategyGenerator
from .reshape_generator import SplitGenerator
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import SplitGenerator, StrategyGenerator
__all__ = ['SplitHandler']

View File

@@ -14,7 +14,13 @@ from .matmul_strategy_generator import (
from .normal_pooling_generator import NormalPoolStrategyGenerator
from .output_generator import OutputGenerator
from .placeholder_generator import PlaceholderGenerator
from .reshape_generator import ReshapeGenerator
from .reshape_generator import (
DefaultReshapeGenerator,
PermuteGenerator,
SplitGenerator,
TransposeGenerator,
ViewGenerator,
)
from .softmax_generator import SoftmaxGenerator
from .strategy_generator import StrategyGenerator
from .sum_generator import SumGenerator
@@ -26,7 +32,8 @@ __all__ = [
'StrategyGenerator', 'DotProductStrategyGenerator', 'MatVecStrategyGenerator', 'LinearProjectionStrategyGenerator',
'BatchedMatMulStrategyGenerator', 'ConvStrategyGenerator', 'UnaryElementwiseGenerator',
'BatchNormStrategyGenerator', 'GetItemStrategyGenerator', 'TensorStrategyGenerator', 'TensorTupleStrategyGenerator',
'LayerNormGenerator', 'ReshapeGenerator', 'PlaceholderGenerator', 'OutputGenerator', 'WhereGenerator',
'ReshapeGenerator', 'NormalPoolStrategyGenerator', 'BinaryElementwiseStrategyGenerator', 'GetattrGenerator',
'TensorConstructorGenerator', 'EmbeddingStrategyGenerator', 'SumGenerator', 'SoftmaxGenerator'
'LayerNormGenerator', 'PlaceholderGenerator', 'OutputGenerator', 'WhereGenerator', 'NormalPoolStrategyGenerator',
'BinaryElementwiseStrategyGenerator', 'GetattrGenerator', 'TensorConstructorGenerator',
'EmbeddingStrategyGenerator', 'SumGenerator', 'SoftmaxGenerator', 'ViewGenerator', 'PermuteGenerator',
'TransposeGenerator', 'SplitGenerator', 'DefaultReshapeGenerator'
]
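After this change the strategy package exports one generator per reshape-style op plus the default fallback; all of them specialise the shared ReshapeGenerator base defined in reshape_generator.py (shown below). A small sanity-check sketch under the import paths declared in this __init__.py (assumes colossalai is installed):

from colossalai.auto_parallel.tensor_shard.node_handler.strategy import (
    DefaultReshapeGenerator,
    PermuteGenerator,
    SplitGenerator,
    TransposeGenerator,
    ViewGenerator,
)
from colossalai.auto_parallel.tensor_shard.node_handler.strategy.reshape_generator import ReshapeGenerator

# every concrete generator is a FollowingStrategyGenerator via the shared base
for generator_cls in (DefaultReshapeGenerator, ViewGenerator, PermuteGenerator,
                      TransposeGenerator, SplitGenerator):
    assert issubclass(generator_cls, ReshapeGenerator)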

View File

@@ -1,6 +1,7 @@
import copy
from typing import List
from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
@@ -8,17 +9,20 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
ShardingStrategy,
TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.utils import (
check_keep_sharding_status,
detect_reshape_mapping,
infer_output_dim_partition_dict,
)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from colossalai.tensor.sharding_spec import ShardingSpec
from .strategy_generator import FollowingStrategyGenerator
__all__ = ['ReshapeGenerator']
__all__ = ['ReshapeGenerator', 'ViewGenerator', 'PermuteGenerator', 'TransposeGenerator', 'SplitGenerator']
class ReshapeGenerator(FollowingStrategyGenerator):
"""
ReshapeGenerator which deals with the sharding strategies of Reshape Op, such as torch.Tensor.permute.
ReshapeGenerator is the base class for all the reshape operation.
"""
def validate(self) -> bool:
@@ -57,11 +61,255 @@ class ReshapeGenerator(FollowingStrategyGenerator):
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
def collate_strategies(self) -> List[ShardingStrategy]:
return super().collate_strategies()
class ViewGenerator(ReshapeGenerator):
"""
ViewGenerator deals with the sharding strategies of view op.
"""
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
# For reshape function, to keep the computing correctness we keep the sharding
# spec of input is fully replicated. In addition, we will keep the output in
# replica status and let the successor node choose the way to resharding the
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
dim_partition_dict_mapping = {}
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
origin_shape = self.op_data['input'].data.shape
tgt_shape = self.op_data['tgt_shape'].data
reshape_mapping_dict = detect_reshape_mapping(origin_shape, tgt_shape)
dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
keep_sharding_status = check_keep_sharding_status(dim_partition_dict_for_input, reshape_mapping_dict)
if keep_sharding_status:
dim_partition_dict_for_output = infer_output_dim_partition_dict(dim_partition_dict_for_input,
reshape_mapping_dict)
else:
dim_partition_dict_for_output = {}
dim_partition_dict_mapping = {
"input": dim_partition_dict_for_input,
"output": dim_partition_dict_for_output,
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# add index into name to pass the duplicated check
# we keep same strategies with different name for node merging, and it will not increase the searching space,
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
if keep_sharding_status:
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
else:
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> FULLY REPLICATED_{index}'
# add comm action for converting input to fully replicated
total_mesh_dim_list = []
for mesh_dim_list in dim_partition_dict_for_input.values():
total_mesh_dim_list.extend(mesh_dim_list)
# if there is only one sharding dimension, we should use the value instead of list as logical_process_axis.
if len(total_mesh_dim_list) == 1:
total_mesh_dim_list = total_mesh_dim_list[0]
# the total mesh dim list only has one element, so the shard dim has only one element as well.
shard_dim = list(dim_partition_dict_for_input.keys())[0]
input_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
logical_process_axis=total_mesh_dim_list,
comm_type=CommType.BEFORE,
arg_index=0)
# it will gather the input through gather_dim during forward phase.
input_comm_action.comm_spec.gather_dim = shard_dim
# it will split the input activation grad through shard_dim during backward phase.
input_comm_action.comm_spec.shard_dim = shard_dim
elif len(total_mesh_dim_list) >= 2:
source_spec = sharding_spec_mapping["input"]
target_spec = ShardingSpec(device_mesh=self.device_mesh,
entire_shape=source_spec.entire_shape,
dim_partition_dict={})
comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)
else:
input_comm_action = None
if input_comm_action is not None:
communication_action_mapping["input"] = input_comm_action
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy_list.append(strategy)
return strategy_list
class PermuteGenerator(ReshapeGenerator):
"""
PermuteGenerator deals with the sharding strategies of permute op.
"""
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
dim_partition_dict_mapping = {}
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
permute_dims = self.op_data['permute_dims'].data
dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
dim_partition_dict_for_output = {}
for dim_index, permute_dim in enumerate(permute_dims):
if permute_dim in dim_partition_dict_for_input:
dim_partition_dict_for_output[dim_index] = dim_partition_dict_for_input[permute_dim]
dim_partition_dict_mapping = {
"input": dim_partition_dict_for_input,
"output": dim_partition_dict_for_output,
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# add index into name to pass the duplicated check
# we keep same strategies with different name for node merging, and it will not increase the searching space,
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy_list.append(strategy)
return strategy_list
class TransposeGenerator(ReshapeGenerator):
"""
TransposeGenerator deals with the sharding strategies of transpose op.
"""
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
dim_partition_dict_mapping = {}
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
dim_partition_dict_for_output = {}
transpose_dims = self.op_data['transpose_dims'].data
dim_0 = transpose_dims[0]
dim_1 = transpose_dims[1]
for dim, sharded_dims in dim_partition_dict_for_input.items():
if dim == dim_0:
dim_partition_dict_for_output[dim_1] = dim_partition_dict_for_input[dim_0]
elif dim == dim_1:
dim_partition_dict_for_output[dim_0] = dim_partition_dict_for_input[dim_1]
else:
dim_partition_dict_for_output[dim] = sharded_dims
dim_partition_dict_mapping = {
"input": dim_partition_dict_for_input,
"output": dim_partition_dict_for_output,
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# add index into name to pass the duplicated check
# we keep same strategies with different name for node merging, and it will not increase the searching space,
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy_list.append(strategy)
return strategy_list
class SplitGenerator(ReshapeGenerator):
"""
SplitGenerator deals with the sharding strategies of split op.
"""
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
recover_dims = None
dim_partition_dict_mapping = {}
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict)
split_size, split_dim = self.op_data['split_info'].data
if split_dim in dim_partition_dict_for_input:
recover_dims = dim_partition_dict_for_input.pop(split_dim)
dim_partition_dict_for_output = [
copy.deepcopy(dim_partition_dict_for_input) for _ in range(len(self.op_data["output"].data))
]
assert len(dim_partition_dict_for_output) >= 2
dim_partition_dict_mapping = {
"input": dim_partition_dict_for_input,
"output": dim_partition_dict_for_output,
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# add index into name to pass the duplicated check
# we keep same strategies with different name for node merging, and it will not increase the searching space,
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
name = f'{sharding_spec_mapping["input"].sharding_sequence}_{index}'
# add comm action if the input need to be recovered to replica in the split dimension.
if recover_dims:
# if there is only one sharding dimension, we should use the value instead of list as logical_process_axis.
if len(recover_dims) == 1:
recover_dims = recover_dims[0]
input_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
logical_process_axis=recover_dims,
comm_type=CommType.BEFORE,
arg_index=0)
# it will gather the input through gather_dim during forward phase.
input_comm_action.comm_spec.gather_dim = split_dim
# it will split the input activation grad through split_dim during backward phase.
input_comm_action.comm_spec.shard_dim = split_dim
elif len(recover_dims) >= 2:
# original sharding spec
source_spec = input_sharding_spec
# target sharding spec
target_spec = sharding_spec_mapping["input"]
comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)
else:
input_comm_action = None
if input_comm_action is not None:
communication_action_mapping["input"] = input_comm_action
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy_list.append(strategy)
return strategy_list
class DefaultReshapeGenerator(ReshapeGenerator):
"""
DefaultReshapeGenerator which deals with the sharding strategies of Reshape Op which have to recover the tensor
to Replica status.
"""
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
# For default reshape strategy, to keep the computing correctness we keep the
# sharding spec of input is fully replicated. In addition, we will keep the output
# in replica status and let the successor node choose the way to resharding the
# output node. Therefore, the different strategies of input node with same
# output sharding spec will generate same strategy for reshape function.
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
@@ -114,9 +362,4 @@ class ReshapeGenerator(FollowingStrategyGenerator):
communication_action_mapping=communication_action_mapping)
strategy_list.append(strategy)
for strategy in strategy_list:
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)
return strategy_list
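To make the follower behaviour concrete: PermuteGenerator and TransposeGenerator only relabel the keys of the predecessor's dim_partition_dict, while ViewGenerator and SplitGenerator gather the input back to replica along dimensions that cannot stay sharded, and DefaultReshapeGenerator always recovers the input to replica status. A standalone sketch of the permute relabelling above, using a made-up partition dict purely for illustration:

def remap_partition_for_permute(dim_partition_dict, permute_dims):
    """Mirrors PermuteGenerator's loop: output dim i is sharded exactly as
    input dim permute_dims[i] was sharded."""
    remapped = {}
    for out_dim, in_dim in enumerate(permute_dims):
        if in_dim in dim_partition_dict:
            remapped[out_dim] = dim_partition_dict[in_dim]
    return remapped

# e.g. an NCHW tensor sharded on N (mesh axis 0) and C (mesh axis 1),
# permuted to NHWC: each shard follows its dimension to the new position
print(remap_partition_for_permute({0: [0], 1: [1]}, (0, 2, 3, 1)))
# -> {0: [0], 3: [1]}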

View File

@@ -2,11 +2,10 @@ from typing import Dict, List
import torch
from ...sharding_strategy import OperationData, OperationDataType
from ..node_handler import NodeHandler
from ..registry import operator_registry
from ..strategy import StrategyGenerator
from .reshape_generator import TransposeGenerator
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import StrategyGenerator, TransposeGenerator
__all__ = ['TransposeHandler']

View File

@@ -2,11 +2,10 @@ from typing import Dict, List
import torch
from ...sharding_strategy import OperationData, OperationDataType
from ..node_handler import NodeHandler
from ..registry import operator_registry
from ..strategy import StrategyGenerator
from .reshape_generator import ViewGenerator
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import StrategyGenerator, ViewGenerator
__all__ = ['ViewHandler']

View File

@@ -1,8 +1,8 @@
import torch
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler import DefaultReshapeHandler
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.reshape_handler import ReshapeHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
@@ -51,7 +51,7 @@ def test_reshape_handler():
strategies_vector=conv_strategies_vector)
conv_handler.register_strategy(compute_resharding_cost=False)
setattr(conv_mod_node, 'strategies_vector', conv_strategies_vector)
reshape_handler = ReshapeHandler(node=reshape_node,
reshape_handler = DefaultReshapeHandler(node=reshape_node,
device_mesh=device_mesh,
strategies_vector=reshape_strategies_vector)

View File

@@ -5,10 +5,10 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler.default_reshape_handler import DefaultReshapeHandler
from colossalai.auto_parallel.tensor_shard.node_handler.getitem_handler import GetItemHandler
from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlaceholderHandler
from colossalai.auto_parallel.tensor_shard.node_handler.reshape_handler import ReshapeHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
@@ -153,7 +153,9 @@ def test_getitem_from_tuple_handler():
)
input_handler.register_strategy(compute_resharding_cost=False)
setattr(input_node, 'strategies_vector', input_strategies_vector)
split_handler = ReshapeHandler(node=split_node, device_mesh=device_mesh, strategies_vector=split_strategies_vector)
split_handler = DefaultReshapeHandler(node=split_node,
device_mesh=device_mesh,
strategies_vector=split_strategies_vector)
split_handler.register_strategy(compute_resharding_cost=False)
setattr(split_node, 'strategies_vector', split_strategies_vector)
getitem_handler = GetItemHandler(node=getitem_node,

View File

@@ -5,8 +5,8 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler import PermuteHandler, TransposeHandler
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.experimental import PermuteHandler, TransposeHandler
from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh

View File

@@ -5,8 +5,8 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler import SplitHandler
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.experimental import SplitHandler
from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
@@ -156,8 +156,7 @@ def check_split_handler(rank, split_size, split_dim, model_cls, world_size, port
# reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node.
assert len(split_strategies_vector) == len(previous_strategies_vector)
strategy_name_list = [strategy.name for strategy in split_strategies_vector]
for name in strategy_name_list:
print(name)
if model_cls.__name__ == 'ConvSplitModel':
if split_dim == 0:

View File

@@ -5,8 +5,8 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler import ViewHandler
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.experimental import ViewHandler
from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh