From 8221fd7485772d0133cb177ef7f5dbf984d7a76e Mon Sep 17 00:00:00 2001
From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Date: Thu, 12 Jan 2023 09:35:10 +0800
Subject: [PATCH] [autoparallel] update binary elementwise handler (#2451)

* [autoparallel] update binary elementwise handler

* polish
---
 .../binary_elementwise_handler.py      | 27 ++++++--
 .../test_binary_elementwise_handler.py | 65 ++++++++++++++-----
 .../test_node_handler/utils.py         |  5 +-
 3 files changed, 74 insertions(+), 23 deletions(-)

diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py
index f510f7477..db8f0b54d 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py
@@ -32,20 +32,32 @@ class BinaryElementwiseHandler(MetaInfoNodeHandler):
             return OperationDataType.ARG
 
         def _get_arg_value(idx):
+            non_tensor = False
             if isinstance(self.node.args[idx], Node):
                 meta_data = self.node.args[idx]._meta_data
+                # The meta_data of a node-type argument could also be a non-tensor object.
+                if not isinstance(meta_data, torch.Tensor):
+                    assert isinstance(meta_data, (int, float))
+                    meta_data = torch.Tensor([meta_data]).to('meta')
+                    non_tensor = True
+
             else:
                 # this is in fact a real data like int 1
                 # but we can deem it as meta data
                 # as it won't affect the strategy generation
                 assert isinstance(self.node.args[idx], (int, float))
                 meta_data = torch.Tensor([self.node.args[idx]]).to('meta')
-            return meta_data
+                non_tensor = True
 
-        input_meta_data = _get_arg_value(0)
-        other_meta_data = _get_arg_value(1)
+            return meta_data, non_tensor
+
+        input_meta_data, non_tensor_input = _get_arg_value(0)
+        other_meta_data, non_tensor_other = _get_arg_value(1)
         output_meta_data = self.node._meta_data
-
+        # we need to record op_data with non-tensor data in this list
+        # and filter the non-tensor op_data in post_process.
+        self.non_tensor_list = []
+        # assert False
         input_op_data = OperationData(name=str(self.node.args[0]),
                                       type=_get_op_data_type(input_meta_data),
                                       data=input_meta_data,
@@ -58,6 +70,10 @@ class BinaryElementwiseHandler(MetaInfoNodeHandler):
                                        type=OperationDataType.OUTPUT,
                                        data=output_meta_data,
                                        logical_shape=bcast_shape)
+        if non_tensor_input:
+            self.non_tensor_list.append(input_op_data)
+        if non_tensor_other:
+            self.non_tensor_list.append(other_op_data)
         mapping = {'input': input_op_data, 'other': other_op_data, 'output': output_op_data}
 
         return mapping
@@ -73,9 +89,10 @@ class BinaryElementwiseHandler(MetaInfoNodeHandler):
         op_data_mapping = self.get_operation_data_mapping()
 
         for op_name, op_data in op_data_mapping.items():
-            if not isinstance(op_data.data, torch.Tensor):
+            if op_data in self.non_tensor_list:
                 # remove the sharding spec if the op_data is not a tensor, e.g. torch.pow(tensor, 2)
                 strategy.sharding_specs.pop(op_data)
+
             else:
                 # convert the logical sharding spec to physical sharding spec if broadcast
                 # e.g. torch.rand(4, 4) + torch.rand(4)
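
Why the handler now tracks wrapped operands explicitly: a node-type argument such as x1.dim() carries an int as its meta data, and the hunk above wraps that int into a one-element meta tensor so strategy generation can treat it like any other operand. After wrapping, isinstance(op_data.data, torch.Tensor) can no longer tell the two cases apart, so non_tensor_list is what lets post_process still drop the sharding spec. A minimal sketch of the graph shape that exercises this path, using plain torch.fx.symbolic_trace rather than the ColoTracer the handler actually runs on:

import torch
import torch.fx as fx
import torch.nn as nn


class AddDimConst(nn.Module):
    """Toy module shaped like the new BEOpModelWithNodeConst test model."""

    def forward(self, x1):
        # x1.dim() becomes a call_method node whose runtime value is an int,
        # so the second argument of the add is a Node that carries no tensor.
        return torch.add(x1, x1.dim())


traced = fx.symbolic_trace(AddDimConst())
for node in traced.graph.nodes:
    if node.op == 'call_function' and node.target is torch.add:
        # Both arguments are fx.Node objects; only the first one refers to a tensor.
        print([type(arg).__name__ for arg in node.args])

Before this patch, such an operand reached the handler as a bare int (the old _get_arg_value returned _meta_data unwrapped); wrapping it keeps the OperationData machinery uniform while the new list preserves the information needed to drop its sharding spec later.
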
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py
index 42430d5a2..50385c045 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py
@@ -122,25 +122,41 @@ def check_binary_elementwise_handler_with_tensor(rank, op, other_dim, world_size
     assert input_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[-1]
 
 
-def check_binary_elementwise_handler_with_int(rank, op, other_dim, world_size, port):
+class BEOpModelWithNodeConst(nn.Module):
+
+    def __init__(self, op):
+        super().__init__()
+        self.op = op
+
+    def forward(self, x1):
+        const = x1.dim()
+        out = self.op(x1, const)
+        return out
+
+
+class BEOpModelWithIntConst(nn.Module):
+
+    def __init__(self, op, const):
+        super().__init__()
+        self.op = op
+        self.const = const
+
+    def forward(self, x1):
+        out = self.op(x1, self.const)
+        return out
+
+
+def check_binary_elementwise_handler_with_int(rank, op, other_dim, model_cls, world_size, port):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
 
-    class BinaryElementwiseOpModel(nn.Module):
-
-        def __init__(self, op, const):
-            super().__init__()
-            self.op = op
-            self.const = const
-
-        def forward(self, x1):
-            out = self.op(x1, self.const)
-            return out
-
     physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (2, 2)
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-    model = BinaryElementwiseOpModel(op, other_dim).cuda()
+    if model_cls == BEOpModelWithNodeConst:
+        model = model_cls(op).cuda()
+    else:
+        model = model_cls(op, other_dim).cuda()
     x1 = torch.rand(4, 4).cuda()
     # the index of binary-elementwise node in computation graph
     node_index = 1
@@ -159,9 +175,14 @@ def check_binary_elementwise_handler_with_int(rank, op, other_dim, world_size, p
     tracer = ColoTracer()
     meta_args = {'x1': torch.rand(4, 4).to('meta')}
     graph = tracer.trace(model, meta_args=meta_args)
+    print(graph)
+    # assert False
    gm = ColoGraphModule(model, graph)
 
-    op_node = list(graph.nodes)[1]
+    if model_cls == BEOpModelWithNodeConst:
+        op_node = list(graph.nodes)[2]
+    else:
+        op_node = list(graph.nodes)[1]
     strategies_vector = StrategiesVector(op_node)
 
     # build handler
@@ -212,7 +233,7 @@ def check_binary_elementwise_handler_with_int(rank, op, other_dim, world_size, p
 @parameterize('other_dim', [1, 2])
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
-def test_binary_elementwise_handler(op, other_dim):
+def test_binary_elementwise_handler_with_tensor(op, other_dim):
     world_size = 4
     run_func_tensor = partial(check_binary_elementwise_handler_with_tensor,
                               op=op,
                               other_dim=other_dim,
                               world_size=world_size,
                               port=free_port())
     mp.spawn(run_func_tensor, nprocs=world_size)
+
+
+@run_on_environment_flag(name='AUTO_PARALLEL')
+@parameterize('op', [torch.add])
+@parameterize('other_dim', [1, 2])
+@parameterize('model_cls', [BEOpModelWithNodeConst, BEOpModelWithIntConst])
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_binary_elementwise_handler_with_int(op, model_cls, other_dim):
+    world_size = 4
     run_func_int = partial(check_binary_elementwise_handler_with_int,
                            op=op,
+                           model_cls=model_cls,
                            other_dim=other_dim,
                            world_size=world_size,
                            port=free_port())
@@ -229,4 +261,5 @@
 
 
 if __name__ == '__main__':
-    test_binary_elementwise_handler()
+    test_binary_elementwise_handler_with_tensor()
+    test_binary_elementwise_handler_with_int()
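
The list(graph.nodes)[2] versus [1] branch in the updated test reflects the extra node that x1.dim() contributes to the traced graph. A rough sketch with plain torch.fx (the test itself traces with ColoTracer and meta args) showing the two node orderings:

import torch
import torch.fx as fx
import torch.nn as nn


class WithNodeConst(nn.Module):
    """Shaped like BEOpModelWithNodeConst: the constant comes from another node."""

    def forward(self, x1):
        return torch.add(x1, x1.dim())


class WithIntConst(nn.Module):
    """Shaped like BEOpModelWithIntConst: the constant is a plain Python int."""

    def forward(self, x1):
        return torch.add(x1, 2)


for model in (WithNodeConst(), WithIntConst()):
    nodes = list(fx.symbolic_trace(model).graph.nodes)
    print(type(model).__name__, [(i, node.op) for i, node in enumerate(nodes)])
# WithNodeConst: placeholder(0), call_method dim(1), call_function add(2), output(3)
# WithIntConst:  placeholder(0), call_function add(1), output(2)
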
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py
index d02e1e31e..db76ed9b8 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py
@@ -90,7 +90,8 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
     solver_options = SolverOptions()
     strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()
-    target_node = list(graph.nodes)[node_index]
+    target_node = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies
+                   ][node_index]
     if node_type == 'normal':
         solution_len = len(strategies_constructor.leaf_strategies)
         solution = [0] * solution_len
@@ -112,7 +113,7 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
         ret = solver.call_solver_serialized_args()
         solution = list(ret[0])
 
     gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
-        gm, solution, device_mesh)
+        gm, solution, device_mesh, strategies_constructor)
     gm = runtime_apply_pass(gm)
     gm.recompile()
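
The utils.py change selects the target node by its position among the nodes that actually received strategies (strategies_constructor.leaf_strategies) rather than by its position in graph.nodes, and it also threads strategies_constructor through to runtime_preparation_pass. A purely illustrative sketch (hypothetical node names, not ColossalAI API) of why the two orderings can diverge once an extra node such as the traced dim() call appears in the graph:

# Order of the nodes as they appear in the traced graph.
graph_nodes = ['x1', 'dim', 'add', 'output']
# Order of the nodes the strategies constructor built strategies for
# (assume, for illustration, that the non-tensor 'dim' node received none).
nodes_with_strategies = ['x1', 'add', 'output']

node_index = 1
print(graph_nodes[node_index])            # 'dim' -> not the binary elementwise node
print(nodes_with_strategies[node_index])  # 'add' -> the node the test wants to inspect
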