[autoparallel] change the following nodes' strategy generation logic (#1636)

* [autoparallel] change the following nodes' strategy generation logic

* fix unit test
YuliangLiu0306 2022-09-27 11:20:52 +08:00 committed by GitHub
parent 59f100510a
commit 03978aad45
2 changed files with 9 additions and 8 deletions


@@ -48,22 +48,23 @@ class UnaryElementwiseHandler(OperatorHandler):
        # For element-wise function, we keep the sharding spec of output node same as
        # the input. Therefore, the different strategies of input node with same
        # output sharding spec will generate same strategy for element-wise function.
        sharding_spec_checklist = []
        for strategy in self.input_node.strategies_vector:
        for index, strategy in enumerate(self.input_node.strategies_vector):
            # It looks a little bit confusing, the input of the processing node
            # is the output of the input_node.
            input_sharding_spec = strategy.output_sharding_spec
            assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
            if input_sharding_spec in sharding_spec_checklist:
                continue
            sharding_spec_checklist.append(input_sharding_spec)
            dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict)
            try:
                output_sharding_spec = self._generate_sharding_spec(self.output_data, dim_partition_dict)
            except AssertionError as e:
                warnings.warn(f'{e}')
                continue
            name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}'
            # add index into name to pass the duplicated check
            # we keep same strategies with different name for node merging, and it will not increase the searching space,
            # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
            name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}_{index}'
            # TODO: use meta_info_prop to profile memory cost and compute cost
            compute_cost = self.output_data.numel()
            memory_cost = 0

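The comment added in this hunk states the intent: identical element-wise strategies may be kept as long as their names stay unique, because the solver later merges this node into other nodes and creates no new variable for it. The toy sketch below (standalone Python with made-up sharding sequences, not the ColossalAI handler API) illustrates why the `_{index}` suffix is needed to keep otherwise colliding names apart:

# Toy illustration only: made-up sharding sequences, no ColossalAI imports.
# For an element-wise op the output sharding sequence equals the input one, so
# two input strategies with the same sequence would produce the same name.
input_sharding_sequences = ['[S0, R, R, R]', '[S0, R, R, R]', '[R, R, R, R]']

names_without_index = set()
names_with_index = []
for index, seq in enumerate(input_sharding_sequences):
    output_seq = seq  # element-wise: output sharding spec matches the input
    names_without_index.add(f'{seq} -> {output_seq}')           # duplicates collapse here
    names_with_index.append(f'{seq} -> {output_seq}_{index}')   # index keeps names unique

print(len(names_without_index))  # 2: two identical names collapsed into one
print(names_with_index)          # ['... _0', '... _1', '... _2'], all distinct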

@@ -59,7 +59,7 @@ def test_strategies_constructor():
    assert strategies_constructor.leaf_strategies[0][0].name == 'Replica Placeholder'
    # Second node is mul which is a element-wise node, therefore the output sharding spec is same as input sharding spec.
    assert strategies_constructor.leaf_strategies[1][0].name == '[R, R, R, R] -> [R, R, R, R]'
    assert strategies_constructor.leaf_strategies[1][0].name == '[R, R, R, R] -> [R, R, R, R]_0'
    # Third node is conv.
    conv_check_list = deepcopy(CONV_STRATEGIES_LIST)
@@ -79,7 +79,7 @@ def test_strategies_constructor():
    # Second node is mul which is a element-wise node, therefore the output sharding spec is same as input sharding spec.
    mul = nodes[1]
    assert strategies_constructor.strategy_map[mul][0].name == '[R, R, R, R] -> [R, R, R, R]'
    assert strategies_constructor.strategy_map[mul][0].name == '[R, R, R, R] -> [R, R, R, R]_0'
    # Third node is conv.
    conv = nodes[2]