[hotfix] add recompile after graph manipulatation (#1621)

pull/1617/head^2
YuliangLiu0306 2022-09-23 11:00:33 +08:00 committed by GitHub
parent d967779a32
commit 9eae855408
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 0 deletions

View File

@ -65,6 +65,7 @@ def check_apply(rank, world_size, port):
solution = list(ret[0])
sharding_spec_dict, origin_spec_dict = solution_annotatation_pass(gm, solution, device_mesh)
shape_consistency_pass(gm)
gm.recompile()
nodes = [node for node in gm.graph.nodes]
# TODO: wrap the gm to avoid the influence of the user training code
output = gm(input, sharding_spec_dict, origin_spec_dict)