|
|
|
@ -61,6 +61,23 @@ class Solver:
|
|
|
|
|
# The last objective value of the best ILP solution. |
|
|
|
|
self.last_objective = None |
|
|
|
|
|
|
|
|
|
def _recover_merged_node_strategy(self): |
|
|
|
|
''' |
|
|
|
|
During cost graph constructing, some nodes, such as unary element-wise node or ReshapeOp, were merged into the previous node. |
|
|
|
|
Therefore, the index of those strategies are copied from the previous node. This method is used to recover the strategy index of those merged |
|
|
|
|
node. |
|
|
|
|
''' |
|
|
|
|
for node_index, node in enumerate(self.graph.nodes): |
|
|
|
|
if node.strategies_vector.check_merge(): |
|
|
|
|
# the merged node has only one input, and its strategies follow the input sharding strategy |
|
|
|
|
input_strategies_vector = node.args[0].strategies_vector |
|
|
|
|
input_best_strategy_index = self.last_s_val[node_index - 1] |
|
|
|
|
input_sharding_spec = input_strategies_vector[input_best_strategy_index].output_sharding_spec |
|
|
|
|
for strategy_index, strategy in enumerate(node.strategies_vector): |
|
|
|
|
if strategy.input_shardings[0].sharding_sequence == input_sharding_spec.sharding_sequence: |
|
|
|
|
self.last_s_val[node_index] = strategy_index |
|
|
|
|
break |
|
|
|
|
|
|
|
|
|
def _generate_node_index_dict(self) -> Dict[Node, int]: |
|
|
|
|
node_index_dict = {} |
|
|
|
|
for index, strategies_vector in enumerate(self.leaf_strategies): |
|
|
|
@ -411,13 +428,14 @@ class Solver:
|
|
|
|
|
if verbose and r[idx][e_val[idx]] > 0: |
|
|
|
|
print(f"Edge cost {(i, j)} : {r[idx][e_val[idx]]}") |
|
|
|
|
|
|
|
|
|
self.last_s_val = s_val |
|
|
|
|
self.last_s_val = list(s_val) |
|
|
|
|
self._recover_merged_node_strategy() |
|
|
|
|
self.last_objective = objective |
|
|
|
|
|
|
|
|
|
if objective > INFINITY_COST: |
|
|
|
|
warnings.warn("Detect unexpected behaviors in the auto-sharding pass.") |
|
|
|
|
|
|
|
|
|
return s_val, e_val, objective, status |
|
|
|
|
return self.last_s_val, e_val, self.last_objective, status |
|
|
|
|
|
|
|
|
|
def call_solver_serialized_args(self): |
|
|
|
|
""" |
|
|
|
|