Browse Source

[autoparallel] fix construct meta info. (#2245)

pull/2254/head
Super Daniel 2 years ago committed by GitHub
parent
commit
b7d0990c61
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 7
      colossalai/auto_parallel/passes/runtime_apply_pass.py

7
colossalai/auto_parallel/passes/runtime_apply_pass.py

@ -62,7 +62,8 @@ def construct_meta_info(node: Node, user_node: Node) -> MetaInfo:
return new_shape
meta_info = MetaInfo()
origin_sharding_spec, target_sharding_spec = node.sharding_spec, user_node.sharding_spec
origin_sharding_spec, target_sharding_spec = node.sharding_spec, user_node.best_strategy.get_sharding_spec_by_name(
str(node.name))
_, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(
origin_sharding_spec, target_sharding_spec)
@ -174,8 +175,8 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
runtime_apply,
args=(node, origin_dict_node, input_dict_node,
node_to_index_dict[node], user_node_index))
# meta_info = construct_meta_info(node, user_node)
# setattr(shape_consistency_node, 'best_metainfo', meta_info)
meta_info = construct_meta_info(node, user_node)
setattr(shape_consistency_node, 'best_metainfo', meta_info)
new_args = list(user_node.args)
new_kwargs = dict(user_node.kwargs)

Loading…
Cancel
Save