mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] recover the merged node strategy index (#1613)
parent
d6b01feb66
commit
bf77d3ab65
|
@ -61,6 +61,23 @@ class Solver:
|
||||||
# The last objective value of the best ILP solution.
|
# The last objective value of the best ILP solution.
|
||||||
self.last_objective = None
|
self.last_objective = None
|
||||||
|
|
||||||
|
def _recover_merged_node_strategy(self):
    '''
    Recover the strategy index of the nodes that were merged into their input node.

    While the cost graph is constructed, some nodes, such as unary element-wise
    nodes or reshape ops, are merged into their input node, so the entry stored
    for them in ``self.last_s_val`` is copied from that node instead of being an
    index into their own strategies vector. For every node whose
    ``strategies_vector.check_merge()`` is true, this method selects the first
    strategy whose first input sharding sequence equals the output sharding
    sequence of the strategy chosen for ``node.args[0]`` and writes its index
    into ``self.last_s_val`` in place. If no strategy matches, the entry is
    left unchanged.

    The method reads ``self.graph.nodes`` and ``self.last_s_val``; the latter
    must be a mutable sequence indexed by the position of the node in
    ``self.graph.nodes``. It returns nothing.
    '''
    nodes = list(self.graph.nodes)
    node_to_index = {node: index for index, node in enumerate(nodes)}

    for node_index, node in enumerate(nodes):
        if not node.strategies_vector.check_merge():
            continue

        # The merged node has only one input, and its strategies follow the
        # sharding strategy of that input.
        input_node = node.args[0]
        input_strategies_vector = input_node.strategies_vector

        # Look up the position of the actual input node rather than assuming
        # it is the node right before this one in graph order; the two differ
        # whenever the graph branches. The previous position is kept only as
        # a fallback for an input that is not found in the graph.
        input_node_index = node_to_index.get(input_node, node_index - 1)
        input_best_strategy_index = self.last_s_val[input_node_index]
        input_sharding_sequence = (
            input_strategies_vector[input_best_strategy_index].output_sharding_spec.sharding_sequence)

        for strategy_index, strategy in enumerate(node.strategies_vector):
            if strategy.input_shardings[0].sharding_sequence == input_sharding_sequence:
                self.last_s_val[node_index] = strategy_index
                break
|
||||||
|
|
||||||
def _generate_node_index_dict(self) -> Dict[Node, int]:
|
def _generate_node_index_dict(self) -> Dict[Node, int]:
|
||||||
node_index_dict = {}
|
node_index_dict = {}
|
||||||
for index, strategies_vector in enumerate(self.leaf_strategies):
|
for index, strategies_vector in enumerate(self.leaf_strategies):
|
||||||
|
@ -411,13 +428,14 @@ class Solver:
|
||||||
if verbose and r[idx][e_val[idx]] > 0:
|
if verbose and r[idx][e_val[idx]] > 0:
|
||||||
print(f"Edge cost {(i, j)} : {r[idx][e_val[idx]]}")
|
print(f"Edge cost {(i, j)} : {r[idx][e_val[idx]]}")
|
||||||
|
|
||||||
self.last_s_val = s_val
|
self.last_s_val = list(s_val)
|
||||||
|
self._recover_merged_node_strategy()
|
||||||
self.last_objective = objective
|
self.last_objective = objective
|
||||||
|
|
||||||
if objective > INFINITY_COST:
|
if objective > INFINITY_COST:
|
||||||
warnings.warn("Detect unexpected behaviors in the auto-sharding pass.")
|
warnings.warn("Detect unexpected behaviors in the auto-sharding pass.")
|
||||||
|
|
||||||
return s_val, e_val, objective, status
|
return self.last_s_val, e_val, self.last_objective, status
|
||||||
|
|
||||||
def call_solver_serialized_args(self):
|
def call_solver_serialized_args(self):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -80,21 +80,20 @@ def test_cost_graph():
|
||||||
gm.recompile()
|
gm.recompile()
|
||||||
graph_analyser = GraphAnalyser(gm)
|
graph_analyser = GraphAnalyser(gm)
|
||||||
liveness_list = graph_analyser.liveness_analysis()
|
liveness_list = graph_analyser.liveness_analysis()
|
||||||
# print(len(liveness_dict[0].unique_live_vars))
|
|
||||||
# assert False
|
|
||||||
solver_options = SolverOptions(fast=True)
|
solver_options = SolverOptions(fast=True)
|
||||||
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
|
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
|
||||||
strategies_constructor.build_strategies_and_cost()
|
strategies_constructor.build_strategies_and_cost()
|
||||||
|
|
||||||
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
|
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
|
||||||
cost_graph.simplify_graph()
|
cost_graph.simplify_graph()
|
||||||
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=1620017824.0)
|
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
|
||||||
# solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
|
|
||||||
|
|
||||||
ret = solver.call_solver_serialized_args()
|
ret = solver.call_solver_serialized_args()
|
||||||
print(ret)
|
print(ret[0])
|
||||||
strategies_list = list(ret[0])
|
solver._recover_merged_node_strategy()
|
||||||
print(strategies_list)
|
print(solver.last_s_val)
|
||||||
|
strategies_list = solver.last_s_val
|
||||||
|
|
||||||
computation_cost = 0
|
computation_cost = 0
|
||||||
communication_cost = 0
|
communication_cost = 0
|
||||||
communication_cost_bn = 0
|
communication_cost_bn = 0
|
||||||
|
@ -102,10 +101,6 @@ def test_cost_graph():
|
||||||
for index, node in enumerate(graph.nodes):
|
for index, node in enumerate(graph.nodes):
|
||||||
if node.op == 'call_module':
|
if node.op == 'call_module':
|
||||||
submod = node.graph.owning_module.get_submodule(node.target)
|
submod = node.graph.owning_module.get_submodule(node.target)
|
||||||
if type(submod) in ELEMENTWISE_MODULE_OP:
|
|
||||||
input_spec = node.args[0].strategies_vector[strategies_list[index]].output_sharding_spec
|
|
||||||
print(node.name, input_spec)
|
|
||||||
continue
|
|
||||||
if type(submod) in BATCHNORM_MODULE_OP:
|
if type(submod) in BATCHNORM_MODULE_OP:
|
||||||
communication_cost_bn += node.strategies_vector[strategies_list[index]].communication_cost
|
communication_cost_bn += node.strategies_vector[strategies_list[index]].communication_cost
|
||||||
print(node.name, node.strategies_vector[strategies_list[index]].name)
|
print(node.name, node.strategies_vector[strategies_list[index]].name)
|
||||||
|
|
Loading…
Reference in New Issue