mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] change the following nodes strategies generation logic (#1636)
* [autoparallel] change the following nodes strategies generation logic * fix unit testpull/1654/head
parent
59f100510a
commit
03978aad45
|
@ -48,22 +48,23 @@ class UnaryElementwiseHandler(OperatorHandler):
|
||||||
# For element-wise function, we keep the sharding spec of output node same as
|
# For element-wise function, we keep the sharding spec of output node same as
|
||||||
# the input. Therefore, the different strategies of input node with same
|
# the input. Therefore, the different strategies of input node with same
|
||||||
# output sharding spec will generate same strategy for element-wise function.
|
# output sharding spec will generate same strategy for element-wise function.
|
||||||
sharding_spec_checklist = []
|
|
||||||
for strategy in self.input_node.strategies_vector:
|
for index, strategy in enumerate(self.input_node.strategies_vector):
|
||||||
# It looks a little bit confusing, the input of the processing node
|
# It looks a little bit confusing, the input of the processing node
|
||||||
# is the output of the input_node.
|
# is the output of the input_node.
|
||||||
input_sharding_spec = strategy.output_sharding_spec
|
input_sharding_spec = strategy.output_sharding_spec
|
||||||
assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
|
assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
|
||||||
if input_sharding_spec in sharding_spec_checklist:
|
|
||||||
continue
|
|
||||||
sharding_spec_checklist.append(input_sharding_spec)
|
|
||||||
dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict)
|
dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict)
|
||||||
try:
|
try:
|
||||||
output_sharding_spec = self._generate_sharding_spec(self.output_data, dim_partition_dict)
|
output_sharding_spec = self._generate_sharding_spec(self.output_data, dim_partition_dict)
|
||||||
except AssertionError as e:
|
except AssertionError as e:
|
||||||
warnings.warn(f'{e}')
|
warnings.warn(f'{e}')
|
||||||
continue
|
continue
|
||||||
name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}'
|
# add index into name to pass the duplicated check
|
||||||
|
# we keep same strategies with different name for node merging, and it will not increase the searching space,
|
||||||
|
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
|
||||||
|
name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}_{index}'
|
||||||
# TODO: use meta_info_prop to profile memory cost and compute cost
|
# TODO: use meta_info_prop to profile memory cost and compute cost
|
||||||
compute_cost = self.output_data.numel()
|
compute_cost = self.output_data.numel()
|
||||||
memory_cost = 0
|
memory_cost = 0
|
||||||
|
|
|
@ -59,7 +59,7 @@ def test_strategies_constructor():
|
||||||
assert strategies_constructor.leaf_strategies[0][0].name == 'Replica Placeholder'
|
assert strategies_constructor.leaf_strategies[0][0].name == 'Replica Placeholder'
|
||||||
|
|
||||||
# Second node is mul which is a element-wise node, therefore the output sharding spec is same as input sharding spec.
|
# Second node is mul which is a element-wise node, therefore the output sharding spec is same as input sharding spec.
|
||||||
assert strategies_constructor.leaf_strategies[1][0].name == '[R, R, R, R] -> [R, R, R, R]'
|
assert strategies_constructor.leaf_strategies[1][0].name == '[R, R, R, R] -> [R, R, R, R]_0'
|
||||||
|
|
||||||
# Third node is conv.
|
# Third node is conv.
|
||||||
conv_check_list = deepcopy(CONV_STRATEGIES_LIST)
|
conv_check_list = deepcopy(CONV_STRATEGIES_LIST)
|
||||||
|
@ -79,7 +79,7 @@ def test_strategies_constructor():
|
||||||
|
|
||||||
# Second node is mul which is a element-wise node, therefore the output sharding spec is same as input sharding spec.
|
# Second node is mul which is a element-wise node, therefore the output sharding spec is same as input sharding spec.
|
||||||
mul = nodes[1]
|
mul = nodes[1]
|
||||||
assert strategies_constructor.strategy_map[mul][0].name == '[R, R, R, R] -> [R, R, R, R]'
|
assert strategies_constructor.strategy_map[mul][0].name == '[R, R, R, R] -> [R, R, R, R]_0'
|
||||||
|
|
||||||
# Third node is conv.
|
# Third node is conv.
|
||||||
conv = nodes[2]
|
conv = nodes[2]
|
||||||
|
|
Loading…
Reference in New Issue