mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] recover the merged node strategy index (#1613)
parent
d6b01feb66
commit
bf77d3ab65
|
@ -61,6 +61,23 @@ class Solver:
|
|||
# The last objective value of the best ILP solution.
|
||||
self.last_objective = None
|
||||
|
||||
def _recover_merged_node_strategy(self):
    '''
    Recover the strategy index of every merged node in ``self.last_s_val``.

    During cost graph construction, some nodes, such as unary element-wise nodes or
    ReshapeOp, are merged into the previous node, so the strategy index recorded for
    them is copied from that node. This method replaces each such entry with the
    index of the merged node's own strategy whose first input sharding sequence
    matches the output sharding sequence of the strategy chosen for its input node.

    ``self.last_s_val`` is updated in place and nothing is returned. If no strategy
    of a merged node matches, its entry is left unchanged.
    '''
    nodes = list(self.graph.nodes)
    # ``last_s_val`` is indexed by a node's position in the graph, so the entry of the
    # input node must be looked up by the input node's own position. The input is not
    # guaranteed to be the node immediately before the merged node.
    node_position = {node: index for index, node in enumerate(nodes)}

    for node_index, node in enumerate(nodes):
        if not node.strategies_vector.check_merge():
            continue

        # the merged node has only one input, and its strategies follow the input sharding strategy
        input_node = node.args[0]
        input_strategies_vector = input_node.strategies_vector
        input_best_strategy_index = self.last_s_val[node_position[input_node]]
        input_sharding_spec = input_strategies_vector[input_best_strategy_index].output_sharding_spec

        for strategy_index, strategy in enumerate(node.strategies_vector):
            if strategy.input_shardings[0].sharding_sequence == input_sharding_spec.sharding_sequence:
                self.last_s_val[node_index] = strategy_index
                # only the first matching strategy is recorded
                break
|
||||
|
||||
def _generate_node_index_dict(self) -> Dict[Node, int]:
|
||||
node_index_dict = {}
|
||||
for index, strategies_vector in enumerate(self.leaf_strategies):
|
||||
|
@ -411,13 +428,14 @@ class Solver:
|
|||
if verbose and r[idx][e_val[idx]] > 0:
|
||||
print(f"Edge cost {(i, j)} : {r[idx][e_val[idx]]}")
|
||||
|
||||
self.last_s_val = s_val
|
||||
self.last_s_val = list(s_val)
|
||||
self._recover_merged_node_strategy()
|
||||
self.last_objective = objective
|
||||
|
||||
if objective > INFINITY_COST:
|
||||
warnings.warn("Detect unexpected behaviors in the auto-sharding pass.")
|
||||
|
||||
return s_val, e_val, objective, status
|
||||
return self.last_s_val, e_val, self.last_objective, status
|
||||
|
||||
def call_solver_serialized_args(self):
|
||||
"""
|
||||
|
|
|
@ -80,21 +80,20 @@ def test_cost_graph():
|
|||
gm.recompile()
|
||||
graph_analyser = GraphAnalyser(gm)
|
||||
liveness_list = graph_analyser.liveness_analysis()
|
||||
# print(len(liveness_dict[0].unique_live_vars))
|
||||
# assert False
|
||||
solver_options = SolverOptions(fast=True)
|
||||
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
|
||||
strategies_constructor.build_strategies_and_cost()
|
||||
|
||||
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
|
||||
cost_graph.simplify_graph()
|
||||
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=1620017824.0)
|
||||
# solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
|
||||
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
|
||||
|
||||
ret = solver.call_solver_serialized_args()
|
||||
print(ret)
|
||||
strategies_list = list(ret[0])
|
||||
print(strategies_list)
|
||||
print(ret[0])
|
||||
solver._recover_merged_node_strategy()
|
||||
print(solver.last_s_val)
|
||||
strategies_list = solver.last_s_val
|
||||
|
||||
computation_cost = 0
|
||||
communication_cost = 0
|
||||
communication_cost_bn = 0
|
||||
|
@ -102,10 +101,6 @@ def test_cost_graph():
|
|||
for index, node in enumerate(graph.nodes):
|
||||
if node.op == 'call_module':
|
||||
submod = node.graph.owning_module.get_submodule(node.target)
|
||||
if type(submod) in ELEMENTWISE_MODULE_OP:
|
||||
input_spec = node.args[0].strategies_vector[strategies_list[index]].output_sharding_spec
|
||||
print(node.name, input_spec)
|
||||
continue
|
||||
if type(submod) in BATCHNORM_MODULE_OP:
|
||||
communication_cost_bn += node.strategies_vector[strategies_list[index]].communication_cost
|
||||
print(node.name, node.strategies_vector[strategies_list[index]].name)
|
||||
|
|
Loading…
Reference in New Issue