From b7d0990c61e9f6590e44330dfe89c92434d7a507 Mon Sep 17 00:00:00 2001 From: Super Daniel <78588128+super-dainiu@users.noreply.github.com> Date: Fri, 30 Dec 2022 19:56:44 +0800 Subject: [PATCH] [autoparallel] fix construct meta info. (#2245) --- colossalai/auto_parallel/passes/runtime_apply_pass.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/colossalai/auto_parallel/passes/runtime_apply_pass.py b/colossalai/auto_parallel/passes/runtime_apply_pass.py index df4a3fde7..5d224542c 100644 --- a/colossalai/auto_parallel/passes/runtime_apply_pass.py +++ b/colossalai/auto_parallel/passes/runtime_apply_pass.py @@ -62,7 +62,8 @@ def construct_meta_info(node: Node, user_node: Node) -> MetaInfo: return new_shape meta_info = MetaInfo() - origin_sharding_spec, target_sharding_spec = node.sharding_spec, user_node.sharding_spec + origin_sharding_spec, target_sharding_spec = node.sharding_spec, user_node.best_strategy.get_sharding_spec_by_name( + str(node.name)) _, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency( origin_sharding_spec, target_sharding_spec) @@ -174,8 +175,8 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule): runtime_apply, args=(node, origin_dict_node, input_dict_node, node_to_index_dict[node], user_node_index)) - # meta_info = construct_meta_info(node, user_node) - # setattr(shape_consistency_node, 'best_metainfo', meta_info) + meta_info = construct_meta_info(node, user_node) + setattr(shape_consistency_node, 'best_metainfo', meta_info) new_args = list(user_node.args) new_kwargs = dict(user_node.kwargs)