[autoparallel] recover the merged node strategy index (#1613)

pull/1617/head^2
YuliangLiu0306 2 years ago committed by GitHub
parent d6b01feb66
commit bf77d3ab65
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -61,6 +61,23 @@ class Solver:
# The last objective value of the best ILP solution.
self.last_objective = None
def _recover_merged_node_strategy(self):
'''
During cost graph constructing, some nodes, such as unary element-wise node or ReshapeOp, were merged into the previous node.
Therefore, the index of those strategies are copied from the previous node. This method is used to recover the strategy index of those merged
node.
'''
for node_index, node in enumerate(self.graph.nodes):
if node.strategies_vector.check_merge():
# the merged node has only one input, and its strategies follow the input sharding strategy
input_strategies_vector = node.args[0].strategies_vector
input_best_strategy_index = self.last_s_val[node_index - 1]
input_sharding_spec = input_strategies_vector[input_best_strategy_index].output_sharding_spec
for strategy_index, strategy in enumerate(node.strategies_vector):
if strategy.input_shardings[0].sharding_sequence == input_sharding_spec.sharding_sequence:
self.last_s_val[node_index] = strategy_index
break
def _generate_node_index_dict(self) -> Dict[Node, int]:
node_index_dict = {}
for index, strategies_vector in enumerate(self.leaf_strategies):
@ -411,13 +428,14 @@ class Solver:
if verbose and r[idx][e_val[idx]] > 0:
print(f"Edge cost {(i, j)} : {r[idx][e_val[idx]]}")
self.last_s_val = s_val
self.last_s_val = list(s_val)
self._recover_merged_node_strategy()
self.last_objective = objective
if objective > INFINITY_COST:
warnings.warn("Detect unexpected behaviors in the auto-sharding pass.")
return s_val, e_val, objective, status
return self.last_s_val, e_val, self.last_objective, status
def call_solver_serialized_args(self):
"""

@ -80,21 +80,20 @@ def test_cost_graph():
gm.recompile()
graph_analyser = GraphAnalyser(gm)
liveness_list = graph_analyser.liveness_analysis()
# print(len(liveness_dict[0].unique_live_vars))
# assert False
solver_options = SolverOptions(fast=True)
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
cost_graph.simplify_graph()
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=1620017824.0)
# solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
ret = solver.call_solver_serialized_args()
print(ret)
strategies_list = list(ret[0])
print(strategies_list)
print(ret[0])
solver._recover_merged_node_strategy()
print(solver.last_s_val)
strategies_list = solver.last_s_val
computation_cost = 0
communication_cost = 0
communication_cost_bn = 0
@ -102,10 +101,6 @@ def test_cost_graph():
for index, node in enumerate(graph.nodes):
if node.op == 'call_module':
submod = node.graph.owning_module.get_submodule(node.target)
if type(submod) in ELEMENTWISE_MODULE_OP:
input_spec = node.args[0].strategies_vector[strategies_list[index]].output_sharding_spec
print(node.name, input_spec)
continue
if type(submod) in BATCHNORM_MODULE_OP:
communication_cost_bn += node.strategies_vector[strategies_list[index]].communication_cost
print(node.name, node.strategies_vector[strategies_list[index]].name)

Loading…
Cancel
Save