|
|
|
@ -61,6 +61,23 @@ class Solver:
|
|
|
|
|
# The last objective value of the best ILP solution.
|
|
|
|
|
self.last_objective = None
|
|
|
|
|
|
|
|
|
|
def _recover_merged_node_strategy(self):
|
|
|
|
|
'''
|
|
|
|
|
During cost graph constructing, some nodes, such as unary element-wise node or ReshapeOp, were merged into the previous node.
|
|
|
|
|
Therefore, the index of those strategies are copied from the previous node. This method is used to recover the strategy index of those merged
|
|
|
|
|
node.
|
|
|
|
|
'''
|
|
|
|
|
for node_index, node in enumerate(self.graph.nodes):
|
|
|
|
|
if node.strategies_vector.check_merge():
|
|
|
|
|
# the merged node has only one input, and its strategies follow the input sharding strategy
|
|
|
|
|
input_strategies_vector = node.args[0].strategies_vector
|
|
|
|
|
input_best_strategy_index = self.last_s_val[node_index - 1]
|
|
|
|
|
input_sharding_spec = input_strategies_vector[input_best_strategy_index].output_sharding_spec
|
|
|
|
|
for strategy_index, strategy in enumerate(node.strategies_vector):
|
|
|
|
|
if strategy.input_shardings[0].sharding_sequence == input_sharding_spec.sharding_sequence:
|
|
|
|
|
self.last_s_val[node_index] = strategy_index
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
def _generate_node_index_dict(self) -> Dict[Node, int]:
|
|
|
|
|
node_index_dict = {}
|
|
|
|
|
for index, strategies_vector in enumerate(self.leaf_strategies):
|
|
|
|
@ -411,13 +428,14 @@ class Solver:
|
|
|
|
|
if verbose and r[idx][e_val[idx]] > 0:
|
|
|
|
|
print(f"Edge cost {(i, j)} : {r[idx][e_val[idx]]}")
|
|
|
|
|
|
|
|
|
|
self.last_s_val = s_val
|
|
|
|
|
self.last_s_val = list(s_val)
|
|
|
|
|
self._recover_merged_node_strategy()
|
|
|
|
|
self.last_objective = objective
|
|
|
|
|
|
|
|
|
|
if objective > INFINITY_COST:
|
|
|
|
|
warnings.warn("Detect unexpected behaviors in the auto-sharding pass.")
|
|
|
|
|
|
|
|
|
|
return s_val, e_val, objective, status
|
|
|
|
|
return self.last_s_val, e_val, self.last_objective, status
|
|
|
|
|
|
|
|
|
|
def call_solver_serialized_args(self):
|
|
|
|
|
"""
|
|
|
|
|