mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] change the following nodes strategies generation logic (#1636)
* [autoparallel] change the following nodes strategies generation logic * fix unit testpull/1654/head
parent
59f100510a
commit
03978aad45
|
@ -48,22 +48,23 @@ class UnaryElementwiseHandler(OperatorHandler):
|
|||
# For element-wise function, we keep the sharding spec of output node same as
|
||||
# the input. Therefore, the different strategies of input node with same
|
||||
# output sharding spec will generate same strategy for element-wise function.
|
||||
sharding_spec_checklist = []
|
||||
for strategy in self.input_node.strategies_vector:
|
||||
|
||||
for index, strategy in enumerate(self.input_node.strategies_vector):
|
||||
# It looks a little bit confusing, the input of the processing node
|
||||
# is the output of the input_node.
|
||||
input_sharding_spec = strategy.output_sharding_spec
|
||||
assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
|
||||
if input_sharding_spec in sharding_spec_checklist:
|
||||
continue
|
||||
sharding_spec_checklist.append(input_sharding_spec)
|
||||
|
||||
dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict)
|
||||
try:
|
||||
output_sharding_spec = self._generate_sharding_spec(self.output_data, dim_partition_dict)
|
||||
except AssertionError as e:
|
||||
warnings.warn(f'{e}')
|
||||
continue
|
||||
name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}'
|
||||
# add index into name to pass the duplicated check
|
||||
# we keep same strategies with different name for node merging, and it will not increase the searching space,
|
||||
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
|
||||
name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}_{index}'
|
||||
# TODO: use meta_info_prop to profile memory cost and compute cost
|
||||
compute_cost = self.output_data.numel()
|
||||
memory_cost = 0
|
||||
|
|
|
@ -59,7 +59,7 @@ def test_strategies_constructor():
|
|||
assert strategies_constructor.leaf_strategies[0][0].name == 'Replica Placeholder'
|
||||
|
||||
# Second node is mul which is a element-wise node, therefore the output sharding spec is same as input sharding spec.
|
||||
assert strategies_constructor.leaf_strategies[1][0].name == '[R, R, R, R] -> [R, R, R, R]'
|
||||
assert strategies_constructor.leaf_strategies[1][0].name == '[R, R, R, R] -> [R, R, R, R]_0'
|
||||
|
||||
# Third node is conv.
|
||||
conv_check_list = deepcopy(CONV_STRATEGIES_LIST)
|
||||
|
@ -79,7 +79,7 @@ def test_strategies_constructor():
|
|||
|
||||
# Second node is mul which is a element-wise node, therefore the output sharding spec is same as input sharding spec.
|
||||
mul = nodes[1]
|
||||
assert strategies_constructor.strategy_map[mul][0].name == '[R, R, R, R] -> [R, R, R, R]'
|
||||
assert strategies_constructor.strategy_map[mul][0].name == '[R, R, R, R] -> [R, R, R, R]_0'
|
||||
|
||||
# Third node is conv.
|
||||
conv = nodes[2]
|
||||
|
|
Loading…
Reference in New Issue