diff --git a/colossalai/auto_parallel/solver/op_handler/unary_elementwise_handler.py b/colossalai/auto_parallel/solver/op_handler/unary_elementwise_handler.py
index c6fbc3d2c..205923fa0 100644
--- a/colossalai/auto_parallel/solver/op_handler/unary_elementwise_handler.py
+++ b/colossalai/auto_parallel/solver/op_handler/unary_elementwise_handler.py
@@ -48,22 +48,23 @@ class UnaryElementwiseHandler(OperatorHandler):
         # For element-wise function, we keep the sharding spec of output node same as
         # the input. Therefore, the different strategies of input node with same
         # output sharding spec will generate same strategy for element-wise function.
-        sharding_spec_checklist = []
-        for strategy in self.input_node.strategies_vector:
+
+        for index, strategy in enumerate(self.input_node.strategies_vector):
             # It looks a little bit confusing, the input of the processing node
             # is the output of the input_node.
             input_sharding_spec = strategy.output_sharding_spec
             assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
-            if input_sharding_spec in sharding_spec_checklist:
-                continue
-            sharding_spec_checklist.append(input_sharding_spec)
+
             dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict)
             try:
                 output_sharding_spec = self._generate_sharding_spec(self.output_data, dim_partition_dict)
             except AssertionError as e:
                 warnings.warn(f'{e}')
                 continue
-            name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}'
+            # append the loop index to the name to pass the duplicate-name check
+            # we keep identical strategies under different names for node merging; this does not increase the search space,
+            # because the solver merges this node into other nodes and does not create a new variable for it.
+            name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}_{index}'
             # TODO: use meta_info_prop to profile memory cost and compute cost
             compute_cost = self.output_data.numel()
             memory_cost = 0
diff --git a/tests/test_auto_parallel/test_strategies_constructor.py b/tests/test_auto_parallel/test_strategies_constructor.py
index ce263829a..ef8dbde03 100644
--- a/tests/test_auto_parallel/test_strategies_constructor.py
+++ b/tests/test_auto_parallel/test_strategies_constructor.py
@@ -59,7 +59,7 @@ def test_strategies_constructor():
     assert strategies_constructor.leaf_strategies[0][0].name == 'Replica Placeholder'
 
     # Second node is mul which is a element-wise node, therefore the output sharding spec is same as input sharding spec.
-    assert strategies_constructor.leaf_strategies[1][0].name == '[R, R, R, R] -> [R, R, R, R]'
+    assert strategies_constructor.leaf_strategies[1][0].name == '[R, R, R, R] -> [R, R, R, R]_0'
 
     # Third node is conv.
     conv_check_list = deepcopy(CONV_STRATEGIES_LIST)
@@ -79,7 +79,7 @@ def test_strategies_constructor():
 
     # Second node is mul which is a element-wise node, therefore the output sharding spec is same as input sharding spec.
     mul = nodes[1]
-    assert strategies_constructor.strategy_map[mul][0].name == '[R, R, R, R] -> [R, R, R, R]'
+    assert strategies_constructor.strategy_map[mul][0].name == '[R, R, R, R] -> [R, R, R, R]_0'
 
     # Third node is conv.
     conv = nodes[2]
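Note on the change above: the handler previously deduplicated input strategies by sharding spec, whereas it now emits one strategy per input strategy and appends the loop index so otherwise-identical names stay unique for the duplicate-name check. A minimal standalone sketch of that naming scheme, using hypothetical spec strings and a hypothetical helper (not ColossalAI code or APIs):

def build_strategy_names(input_specs):
    # One name per input strategy; the index suffix keeps duplicates distinct
    # so a later name-based duplicate check does not drop any of them.
    return [f'{spec} -> {spec}_{index}' for index, spec in enumerate(input_specs)]

# Two replicated inputs now produce two distinct names instead of colliding.
print(build_strategy_names(['[R, R, R, R]', '[R, R, R, R]']))
# ['[R, R, R, R] -> [R, R, R, R]_0', '[R, R, R, R] -> [R, R, R, R]_1']

Keeping the duplicate strategies is harmless because, as the new comments state, the solver merges the element-wise node into its neighbours and introduces no extra decision variable for it; the tests are updated only because the first strategy's name now carries the '_0' suffix.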