From 9eae8554081f8e0b7a2bc434ac91289a07f35dbe Mon Sep 17 00:00:00 2001 From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com> Date: Fri, 23 Sep 2022 11:00:33 +0800 Subject: [PATCH] [hotfix] add recompile after graph manipulatation (#1621) --- tests/test_auto_parallel/test_shape_consistency_pass.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_auto_parallel/test_shape_consistency_pass.py b/tests/test_auto_parallel/test_shape_consistency_pass.py index 2a7b745f8..6cb46c1de 100644 --- a/tests/test_auto_parallel/test_shape_consistency_pass.py +++ b/tests/test_auto_parallel/test_shape_consistency_pass.py @@ -65,6 +65,7 @@ def check_apply(rank, world_size, port): solution = list(ret[0]) sharding_spec_dict, origin_spec_dict = solution_annotatation_pass(gm, solution, device_mesh) shape_consistency_pass(gm) + gm.recompile() nodes = [node for node in gm.graph.nodes] # TODO: wrap the gm to avoid the influence of the user training code output = gm(input, sharding_spec_dict, origin_spec_dict)