From bf77d3ab6577f75254f4763584de13ee44899148 Mon Sep 17 00:00:00 2001 From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com> Date: Fri, 23 Sep 2022 11:52:42 +0800 Subject: [PATCH] [autoparallel] recover the merged node strategy index (#1613) --- colossalai/auto_parallel/solver/solver.py | 22 +++++++++++++++++-- .../test_solver_with_resnet.py | 17 +++++--------- 2 files changed, 26 insertions(+), 13 deletions(-) diff --git a/colossalai/auto_parallel/solver/solver.py b/colossalai/auto_parallel/solver/solver.py index 63c35c2fc..50f5c696f 100644 --- a/colossalai/auto_parallel/solver/solver.py +++ b/colossalai/auto_parallel/solver/solver.py @@ -61,6 +61,23 @@ class Solver: # The last objective value of the best ILP solution. self.last_objective = None + def _recover_merged_node_strategy(self): + ''' + During cost graph constructing, some nodes, such as unary element-wise node or ReshapeOp, were merged into the previous node. + Therefore, the index of those strategies are copied from the previous node. This method is used to recover the strategy index of those merged + node. + ''' + for node_index, node in enumerate(self.graph.nodes): + if node.strategies_vector.check_merge(): + # the merged node has only one input, and its strategies follow the input sharding strategy + input_strategies_vector = node.args[0].strategies_vector + input_best_strategy_index = self.last_s_val[node_index - 1] + input_sharding_spec = input_strategies_vector[input_best_strategy_index].output_sharding_spec + for strategy_index, strategy in enumerate(node.strategies_vector): + if strategy.input_shardings[0].sharding_sequence == input_sharding_spec.sharding_sequence: + self.last_s_val[node_index] = strategy_index + break + def _generate_node_index_dict(self) -> Dict[Node, int]: node_index_dict = {} for index, strategies_vector in enumerate(self.leaf_strategies): @@ -411,13 +428,14 @@ class Solver: if verbose and r[idx][e_val[idx]] > 0: print(f"Edge cost {(i, j)} : {r[idx][e_val[idx]]}") - self.last_s_val = s_val + self.last_s_val = list(s_val) + self._recover_merged_node_strategy() self.last_objective = objective if objective > INFINITY_COST: warnings.warn("Detect unexpected behaviors in the auto-sharding pass.") - return s_val, e_val, objective, status + return self.last_s_val, e_val, self.last_objective, status def call_solver_serialized_args(self): """ diff --git a/tests/test_auto_parallel/test_solver_with_resnet.py b/tests/test_auto_parallel/test_solver_with_resnet.py index 8d133886a..a46ceb700 100644 --- a/tests/test_auto_parallel/test_solver_with_resnet.py +++ b/tests/test_auto_parallel/test_solver_with_resnet.py @@ -80,21 +80,20 @@ def test_cost_graph(): gm.recompile() graph_analyser = GraphAnalyser(gm) liveness_list = graph_analyser.liveness_analysis() - # print(len(liveness_dict[0].unique_live_vars)) - # assert False solver_options = SolverOptions(fast=True) strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) strategies_constructor.build_strategies_and_cost() cost_graph = CostGraph(strategies_constructor.leaf_strategies) cost_graph.simplify_graph() - solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=1620017824.0) - # solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser) + solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser) ret = solver.call_solver_serialized_args() - print(ret) - strategies_list = list(ret[0]) - print(strategies_list) + print(ret[0]) + solver._recover_merged_node_strategy() + print(solver.last_s_val) + strategies_list = solver.last_s_val + computation_cost = 0 communication_cost = 0 communication_cost_bn = 0 @@ -102,10 +101,6 @@ def test_cost_graph(): for index, node in enumerate(graph.nodes): if node.op == 'call_module': submod = node.graph.owning_module.get_submodule(node.target) - if type(submod) in ELEMENTWISE_MODULE_OP: - input_spec = node.args[0].strategies_vector[strategies_list[index]].output_sharding_spec - print(node.name, input_spec) - continue if type(submod) in BATCHNORM_MODULE_OP: communication_cost_bn += node.strategies_vector[strategies_list[index]].communication_cost print(node.name, node.strategies_vector[strategies_list[index]].name)