diff --git a/colossalai/auto_parallel/solver/op_handler/reshape_handler_v2.py b/colossalai/auto_parallel/solver/op_handler/reshape_handler_v2.py new file mode 100644 index 000000000..977a4c94a --- /dev/null +++ b/colossalai/auto_parallel/solver/op_handler/reshape_handler_v2.py @@ -0,0 +1,35 @@ +import torch +from .node_handler import NodeHandler +from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData, StrategiesVector +from ..strategy import ReshapeGenerator, StrategyGenerator_V2 +from typing import List, Dict +from .registry import operator_registry +import operator + +__all__ = ['ReshapeHandler'] + + +@operator_registry.register(torch.reshape) +@operator_registry.register(torch.Tensor.permute) +class ReshapeHandler(NodeHandler): + """ + A ReshapeHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator_V2]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(ReshapeGenerator(op_data_mapping, self.device_mesh, self.node.args[0])) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # use transposed shape for strategies + # the strategies will be transformed back to its original shape in self.post_process + physical_input_operand = OperationData(name=str(self.node.args[0]), + type=OperationDataType.ARG, + data=self.node.args[0]._meta_data) + physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) + + mapping = {"input": physical_input_operand, "output": physical_output} + + return mapping diff --git a/colossalai/auto_parallel/solver/strategy/__init__.py b/colossalai/auto_parallel/solver/strategy/__init__.py index ae6249205..b65da3f16 100644 --- a/colossalai/auto_parallel/solver/strategy/__init__.py +++ b/colossalai/auto_parallel/solver/strategy/__init__.py @@ -5,10 +5,11 @@ from .batch_norm_generator import BatchNormStrategyGenerator from .unary_elementwise_generator import UnaryElementwiseGenerator from .getitem_generator import GetItemStrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator from .layer_norm_generator import LayerNormGenerator +from .reshape_generator import ReshapeGenerator __all__ = [ 'StrategyGenerator_V2', 'DotProductStrategyGenerator', 'MatVecStrategyGenerator', - 'LinearProjectionStrategyGenerator', 'BatchedMatMulStrategyGenerator', 'ConvStrategyGenerator', 'UnaryElementwiseGenerator', - 'BatchNormStrategyGenerator', 'GetItemStrategyGenerator', 'TensorStrategyGenerator', 'TensorTupleStrategyGenerator', - 'LayerNormGenerator' + 'LinearProjectionStrategyGenerator', 'BatchedMatMulStrategyGenerator', 'ConvStrategyGenerator', + 'UnaryElementwiseGenerator', 'BatchNormStrategyGenerator', 'GetItemStrategyGenerator', 'TensorStrategyGenerator', + 'TensorTupleStrategyGenerator', 'LayerNormGenerator', 'ReshapeGenerator' ] diff --git a/colossalai/auto_parallel/solver/strategy/batch_norm_generator.py b/colossalai/auto_parallel/solver/strategy/batch_norm_generator.py index a89517004..3e7302c27 100644 --- a/colossalai/auto_parallel/solver/strategy/batch_norm_generator.py +++ b/colossalai/auto_parallel/solver/strategy/batch_norm_generator.py @@ -37,7 +37,7 @@ class BatchNormStrategyGenerator(StrategyGenerator_V2): assert input_op_data.dim() in (3, 4, 5), f'We suppose the dim of input fed into conv op should in range of [3, 5].' - def update_compute_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem: + def update_compute_cost(self, strategy: ShardingStrategy_V2): ''' Compute the computation cost per device with this specific strategy. @@ -62,9 +62,9 @@ class BatchNormStrategyGenerator(StrategyGenerator_V2): backward_compute_cost += bias_compute_cost total_compute_cost = forward_compute_cost + backward_compute_cost compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost) - return compute_cost + strategy.compute_cost = compute_cost - def update_memory_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem: + def update_memory_cost(self, strategy: ShardingStrategy_V2): forward_size_mapping = { 'input': self._compute_size_in_bytes(strategy, "input"), 'other': self._compute_size_in_bytes(strategy, "other"), diff --git a/colossalai/auto_parallel/solver/strategy/conv_strategy_generator.py b/colossalai/auto_parallel/solver/strategy/conv_strategy_generator.py index ef989f92c..a599aca66 100644 --- a/colossalai/auto_parallel/solver/strategy/conv_strategy_generator.py +++ b/colossalai/auto_parallel/solver/strategy/conv_strategy_generator.py @@ -29,7 +29,7 @@ class ConvStrategyGenerator(StrategyGenerator_V2): assert input_op_data.dim() in (3, 4, 5), f'We suppose the dim of input fed into conv op should in range of [3, 5].' - def update_compute_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem: + def update_compute_cost(self, strategy: ShardingStrategy_V2): ''' Compute the computation cost per device with this specific strategy. @@ -67,9 +67,9 @@ class ConvStrategyGenerator(StrategyGenerator_V2): total_compute_cost = forward_compute_cost + backward_compute_cost compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost) - return compute_cost + strategy.compute_cost = compute_cost - def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2: + def update_memory_cost(self, strategy: ShardingStrategy_V2): forward_size_mapping = { 'input': self._compute_size_in_bytes(strategy, "input"), 'other': self._compute_size_in_bytes(strategy, "other"), diff --git a/colossalai/auto_parallel/solver/strategy/getitem_generator.py b/colossalai/auto_parallel/solver/strategy/getitem_generator.py index 43f2eb550..0e1287eae 100644 --- a/colossalai/auto_parallel/solver/strategy/getitem_generator.py +++ b/colossalai/auto_parallel/solver/strategy/getitem_generator.py @@ -28,10 +28,11 @@ class GetItemStrategyGenerator(FollowingStrategyGenerator): def validate(self) -> bool: return super().validate() - def update_compute_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem: - return TrainCycleItem(fwd=10, bwd=10, total=20) + def update_compute_cost(self, strategy: ShardingStrategy_V2): + compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20) + strategy.compute_cost = compute_cost - def update_memory_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem: + def update_memory_cost(self, strategy: ShardingStrategy_V2): ''' Compute the memory cost per device with this specific strategy. ''' @@ -59,7 +60,6 @@ class GetItemStrategyGenerator(FollowingStrategyGenerator): parameter=fwd_parameter_cost + bwd_parameter_cost) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) strategy.memory_cost = memory_cost - return super().update_memory_cost(strategy) class TensorStrategyGenerator(GetItemStrategyGenerator): diff --git a/colossalai/auto_parallel/solver/strategy/layer_norm_generator.py b/colossalai/auto_parallel/solver/strategy/layer_norm_generator.py index d20a7d821..97130a0b9 100644 --- a/colossalai/auto_parallel/solver/strategy/layer_norm_generator.py +++ b/colossalai/auto_parallel/solver/strategy/layer_norm_generator.py @@ -23,7 +23,7 @@ class LayerNormGenerator(StrategyGenerator_V2): def validate(self) -> bool: return super().validate() - def update_compute_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem: + def update_compute_cost(self, strategy: ShardingStrategy_V2): ''' Compute the computation cost per device with this specific strategy. @@ -52,9 +52,9 @@ class LayerNormGenerator(StrategyGenerator_V2): backward_compute_cost += bias_compute_cost total_compute_cost = forward_compute_cost + backward_compute_cost compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost) - return compute_cost + strategy.compute_cost = compute_cost - def update_memory_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem: + def update_memory_cost(self, strategy: ShardingStrategy_V2): ''' Compute the memory cost per device with this specific strategy. ''' @@ -103,6 +103,9 @@ class LayerNormGenerator(StrategyGenerator_V2): total_mesh_dim_list = [] for mesh_dim_list in dim_partition.values(): total_mesh_dim_list.extend(mesh_dim_list) + # if there is only one sharding dimension, we should use the value instead of list as logical_process_axis. + if len(total_mesh_dim_list) == 1: + total_mesh_dim_list = total_mesh_dim_list[0] communication_action_mapping = {} other_comm_spec = self.get_communication_spec( diff --git a/colossalai/auto_parallel/solver/strategy/reshape_generator.py b/colossalai/auto_parallel/solver/strategy/reshape_generator.py new file mode 100644 index 000000000..401764aed --- /dev/null +++ b/colossalai/auto_parallel/solver/strategy/reshape_generator.py @@ -0,0 +1,100 @@ +import operator +from functools import reduce +from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost +from colossalai.tensor.shape_consistency import CollectiveCommPattern +from .strategy_generator import FollowingStrategyGenerator +from typing import List +import copy + +__all__ = ['ReshapeGenerator'] + + +class ReshapeGenerator(FollowingStrategyGenerator): + """ + ReshapeGenerator which deals with the sharding strategies of Reshape Op, such as torch.Tensor.permute. + """ + + def validate(self) -> bool: + return super().validate() + + def update_compute_cost(self, strategy: ShardingStrategy_V2): + compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20) + strategy.compute_cost = compute_cost + + def update_memory_cost(self, strategy: ShardingStrategy_V2): + ''' + Compute the memory cost per device with this specific strategy. + ''' + forward_size_mapping = { + 'input': self._compute_size_in_bytes(strategy, "input"), + 'output': self._compute_size_in_bytes(strategy, "output") + } + + backward_size_mapping = copy.deepcopy(forward_size_mapping) + backward_size_mapping.pop("output") + # compute fwd cost incurred + # fwd_cost = input + output + fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)]) + fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)]) + fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost) + + # compute bwd cost incurred + # bwd_cost = input_grad + bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)]) + bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)]) + bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) + + # compute total cost + total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, + parameter=fwd_parameter_cost + bwd_parameter_cost) + memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + strategy.memory_cost = memory_cost + + def generate(self): + strategy_list = [] + # For reshape function, to keep the computing correctness we keep the sharding + # spec of input is fully replicated. In addition, we will keep the output in + # replica status and let the successor node choose the way to resharding the + # output node. Therefore, the different strategies of input node with same + # output sharding spec will generate same strategy for reshape function. + for index, strategy in enumerate(self.predecessor_node.strategies_vector): + dim_partition_dict_mapping = {} + communication_action_mapping = {} + input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] + dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict + dim_partition_dict_for_output = {} + if isinstance(self.op_data["output"].data, tuple): + dim_partition_dict_for_output = [{} for _ in range(len(self.op_data["output"].data))] + dim_partition_dict_mapping = { + "input": dim_partition_dict_for_input, + "output": dim_partition_dict_for_output, + } + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + # add index into name to pass the duplicated check + # we keep same strategies with different name for node merging, and it will not increase the searching space, + # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. + name = f'{sharding_spec_mapping["input"].sharding_sequence} -> FULLY REPLICATED_{index}' + + total_mesh_dim_list = [] + for mesh_dim_list in dim_partition_dict_for_input.values(): + total_mesh_dim_list.extend(mesh_dim_list) + # if there is only one sharding dimension, we should use the value instead of list as logical_process_axis. + if len(total_mesh_dim_list) == 1: + total_mesh_dim_list = total_mesh_dim_list[0] + + input_comm_spec = self.get_communication_spec( + sharding_spec=sharding_spec_mapping["input"], + communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, + logical_process_axis=total_mesh_dim_list) + communication_action_mapping["input"] = input_comm_spec + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + strategy_list.append(strategy) + + for strategy in strategy_list: + self.update_communication_cost(strategy) + self.update_compute_cost(strategy) + self.update_memory_cost(strategy) + + return strategy_list diff --git a/colossalai/auto_parallel/solver/strategy/strategy_generator.py b/colossalai/auto_parallel/solver/strategy/strategy_generator.py index d44d86ad3..5bf3a8327 100644 --- a/colossalai/auto_parallel/solver/strategy/strategy_generator.py +++ b/colossalai/auto_parallel/solver/strategy/strategy_generator.py @@ -53,9 +53,16 @@ class StrategyGenerator_V2(ABC): for op_data_name, dim_partition_dict in mapping.items(): if op_data_name in self.op_data: op_data = self.op_data[op_data_name] - sharding_spec = ShardingSpec(device_mesh=self.device_mesh, - entire_shape=op_data.logical_shape, - dim_partition_dict=dim_partition_dict) + if isinstance(op_data.data, tuple) and isinstance(op_data.data[0], torch.Tensor): + sharding_spec = [] + for output, dim_partition_dict_element in zip(op_data.data, dim_partition_dict): + sharding_spec = ShardingSpec(device_mesh=self.device_mesh, + entire_shape=output.shape, + dim_partition_dict=dim_partition_dict_element) + else: + sharding_spec = ShardingSpec(device_mesh=self.device_mesh, + entire_shape=op_data.logical_shape, + dim_partition_dict=dim_partition_dict) results[op_data_name] = sharding_spec return results diff --git a/colossalai/auto_parallel/solver/strategy/unary_elementwise_generator.py b/colossalai/auto_parallel/solver/strategy/unary_elementwise_generator.py index c00c6b304..99db359e4 100644 --- a/colossalai/auto_parallel/solver/strategy/unary_elementwise_generator.py +++ b/colossalai/auto_parallel/solver/strategy/unary_elementwise_generator.py @@ -18,10 +18,11 @@ class UnaryElementwiseGenerator(FollowingStrategyGenerator): def validate(self) -> bool: return super().validate() - def update_compute_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem: - return TrainCycleItem(fwd=10, bwd=10, total=20) + def update_compute_cost(self, strategy: ShardingStrategy_V2): + compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20) + strategy.compute_cost = compute_cost - def update_memory_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem: + def update_memory_cost(self, strategy: ShardingStrategy_V2): ''' Compute the memory cost per device with this specific strategy. ''' @@ -49,7 +50,6 @@ class UnaryElementwiseGenerator(FollowingStrategyGenerator): parameter=fwd_parameter_cost + bwd_parameter_cost) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) strategy.memory_cost = memory_cost - return super().update_memory_cost(strategy) def generate(self): strategy_list = [] diff --git a/tests/test_auto_parallel/test_node_handler/test_reshape_handler_v2.py b/tests/test_auto_parallel/test_node_handler/test_reshape_handler_v2.py new file mode 100644 index 000000000..e9a77b65b --- /dev/null +++ b/tests/test_auto_parallel/test_node_handler/test_reshape_handler_v2.py @@ -0,0 +1,81 @@ +import torch +import torch.nn as nn +from colossalai.fx import ColoTracer, ColoGraphModule +from colossalai.auto_parallel.solver.op_handler.conv_handler_v2 import ConvFunctionHandler +from colossalai.auto_parallel.solver.op_handler.reshape_handler_v2 import ReshapeHandler +from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh + + +class ReshapeModel(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, input, other): + conv_node = nn.functional.conv2d(input, other) + reshape_node = conv_node.view(2, -1) + return reshape_node + + +def test_reshape_handler(): + model = ReshapeModel() + tracer = ColoTracer() + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %other : torch.Tensor [#users=1] = placeholder[target=other] + # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {}) + # %view : [#users=1] = call_method[target=view](args = (%conv2d, 2, -1), kwargs = {}) + # return view + graph = tracer.trace(model, + meta_args={ + "input": torch.rand(4, 4, 64, 64).to('meta'), + "other": torch.rand(4, 16, 3, 3).to('meta'), + }) + gm = ColoGraphModule(model, graph) + physical_mesh_id = torch.arange(0, 4) + + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + conv_mod_node = list(graph.nodes)[2] + reshape_node = list(graph.nodes)[3] + reshape_strategies_vector = StrategiesVector(reshape_node) + conv_strategies_vector = StrategiesVector(conv_mod_node) + + # build handler + conv_handler = ConvFunctionHandler(node=conv_mod_node, + device_mesh=device_mesh, + strategies_vector=conv_strategies_vector) + conv_handler.register_strategy() + setattr(conv_mod_node, 'strategies_vector', conv_strategies_vector) + reshape_handler = ReshapeHandler(node=reshape_node, + device_mesh=device_mesh, + strategies_vector=reshape_strategies_vector) + + reshape_handler.register_strategy() + + # check operation data mapping + mapping = reshape_handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.data is not None + + assert mapping['input'].name == "conv2d" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([4, 4, 62, 62]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([4, 4, 62, 62]) + + assert mapping['output'].name == "view" + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.Size([2, 30752]) + assert mapping['output'].type == OperationDataType.OUTPUT + + # reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node. + assert len(reshape_strategies_vector) == len(conv_strategies_vector) + + +if __name__ == '__main__': + test_reshape_handler()