From 05020e50d076852943e4a4f1b30d30a197dd71e3 Mon Sep 17 00:00:00 2001
From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Date: Fri, 18 Nov 2022 17:01:06 +0800
Subject: [PATCH] [autoparallel] support more flexible data type (#1967)

---
 .../tensor_shard/node_handler/__init__.py     |  4 +-
 .../tensor_shard/node_handler/node_handler.py |  4 ++
 .../node_handler/reshape_handler.py           | 33 +++++++++++++-
 .../strategy/strategy_generator.py            | 45 ++++++++++++++-----
 .../node_handler/unary_elementwise_handler.py |  2 +
 5 files changed, 74 insertions(+), 14 deletions(-)

diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py
index 20d9d7c38..ab0063dd1 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py
@@ -4,6 +4,7 @@ from .binary_elementwise_handler import BinaryElementwiseHandler
 from .bmm_handler import AddBMMFunctionHandler, BMMFunctionHandler
 from .conv_handler import ConvFunctionHandler, ConvModuleHandler
 from .getatrr_handler import GetattrHandler
+from .getitem_handler import GetItemHandler
 from .layer_norm_handler import LayerNormModuleHandler
 from .linear_handler import LinearFunctionHandler, LinearModuleHandler
 from .matmul_handler import MatMulHandler
@@ -19,5 +20,6 @@ __all__ = [
     'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'AddBMMFunctionHandler',
     'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler',
     'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler',
-    'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler', 'GetattrHandler'
+    'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler',
+    'GetItemHandler', 'GetattrHandler'
 ]
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py
index 2d882fc09..826225a62 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py
@@ -51,6 +51,10 @@ class NodeHandler(ABC):
         for node in self.predecessor_node:
             node_name = str(node)
             # get the current sharding spec generated by this node handler
+
+            # TODO: we need to check this in future
+            if not isinstance(node._meta_data, torch.Tensor):
+                continue
             op_data = strategy.get_op_data_by_name(node_name)
             current_sharding_spec = strategy.sharding_specs[op_data]
 
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py
index d6a06bc15..3c232f131 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py
@@ -11,7 +11,9 @@ __all__ = ['ReshapeHandler']
 
 
 @operator_registry.register(torch.reshape)
+@operator_registry.register(torch.Tensor.split)
 @operator_registry.register(torch.flatten)
+@operator_registry.register(torch.Tensor.transpose)
 @operator_registry.register(torch.Tensor.permute)
 @operator_registry.register(torch.Tensor.view)
 @operator_registry.register(torch.nn.AdaptiveAvgPool2d)
@@ -26,6 +28,24 @@ class ReshapeHandler(NodeHandler):
         generators.append(ReshapeGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
         return generators
 
+    def infer_logical_shape(self, data):
+        """
+        This function is used to infer the logical shape for operands.
+
+        Notes: This function is only needed for operands whose data is not simply a tensor,
+               e.g. a tuple of tensors.
+        """
+        if isinstance(data, torch.Tensor):
+            return data.shape
+        else:
+            assert isinstance(data, tuple), "input_data should be a tuple of tensors or a tensor."
+            logical_shape = []
+            for tensor in data:
+                assert isinstance(tensor, torch.Tensor), "input_data should be a tuple of tensors or a tensor."
+                logical_shape.append(tensor.shape)
+            logical_shape = tuple(logical_shape)
+            return logical_shape
+
     def get_operation_data_mapping(self) -> Dict[str, OperationData]:
         # use transposed shape for strategies
         # the strategies will be transformed back to its original shape in self.post_process
@@ -36,10 +56,19 @@ class ReshapeHandler(NodeHandler):
         else:
             data_type = OperationDataType.ARG
 
+        input_data = self.node.args[0]._meta_data
+        input_logical_shape = self.infer_logical_shape(input_data)
         physical_input_operand = OperationData(name=str(self.node.args[0]),
                                                type=data_type,
-                                               data=self.node.args[0]._meta_data)
-        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
+                                               data=input_data,
+                                               logical_shape=input_logical_shape)
+
+        output_data = self.node._meta_data
+        output_logical_shape = self.infer_logical_shape(output_data)
+        physical_output = OperationData(name=str(self.node),
+                                        type=OperationDataType.OUTPUT,
+                                        data=output_data,
+                                        logical_shape=output_logical_shape)
 
         mapping = {"input": physical_input_operand, "output": physical_output}
 
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py
index d67ef1f49..ca17fbaf4 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py
@@ -81,9 +81,10 @@ class StrategyGenerator(ABC):
                     for logical_shape, dim_partition_dict_element in zip(op_data.logical_shape, dim_partition_dict):
                         dim_size = len(logical_shape)
                         dim_partition_dict_element = convert_dim_partition_dict(dim_size, dim_partition_dict_element)
-                        sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
-                                                     entire_shape=logical_shape,
-                                                     dim_partition_dict=dim_partition_dict_element)
+                        sharding_spec_element = ShardingSpec(device_mesh=self.device_mesh,
+                                                             entire_shape=logical_shape,
+                                                             dim_partition_dict=dim_partition_dict_element)
+                        sharding_spec.append(sharding_spec_element)
                 else:
                     assert isinstance(
                         op_data.data, torch.Tensor
@@ -193,18 +194,40 @@ class StrategyGenerator(ABC):
         Args:
             strategy (ShardingStrategy): the ShardingStrategy generated.
             key (str): the name of the operation data defined by the generator.
- """ op_data = self.op_data[key] - sharded_shape = strategy.sharding_specs[op_data].get_sharded_shape_per_device() - if len(sharded_shape) == 0: - num_elements = 1 + def _compute_size_in_bytes_helper(sharding_spec, meta_data): + sharded_shape = sharding_spec.get_sharded_shape_per_device() + if len(sharded_shape) == 0: + num_elements = 1 + else: + num_elements = reduce(operator.mul, sharded_shape) + dtype = getattr(meta_data, 'dtype') + size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() + return num_elements * size_per_elem_bytes + + if isinstance(op_data.data, tuple): + assert isinstance(strategy.sharding_specs[op_data], list), \ + 'sharding_spec of op_data should be a list of sharding specs if op_data.data is a tuple.' + total_bytes = 0 + for index, sharding_spec in enumerate(strategy.sharding_specs[op_data]): + meta_data = op_data.data[index] + if isinstance(meta_data, torch.Tensor): + element_bytes = _compute_size_in_bytes_helper(sharding_spec, meta_data) + else: + # if meta_data is not a tensor, we count the memroy as 0 + element_bytes = 0 + total_bytes += element_bytes + else: - num_elements = reduce(operator.mul, sharded_shape) - dtype = self.op_data[key].data.dtype - size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() - return num_elements * size_per_elem_bytes + if isinstance(op_data.data, torch.Tensor): + total_bytes = _compute_size_in_bytes_helper(strategy.sharding_specs[op_data], op_data.data) + else: + # if op_data.data is not a tensor, we count the memroy as 0 + total_bytes = 0 + + return total_bytes def generate(self) -> List[ShardingStrategy]: """ diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py index b99d4a071..334528019 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py @@ -10,6 +10,8 @@ from .strategy import StrategyGenerator, UnaryElementwiseGenerator __all__ = ['UnaryElementwiseHandler'] +@operator_registry.register(torch.Tensor.to) +@operator_registry.register(torch.Tensor.type) @operator_registry.register(torch.abs) @operator_registry.register(torch.nn.ReLU) class UnaryElementwiseHandler(NodeHandler):