mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] fix construct meta info. (#2245)
parent
89542ceb44
commit
b7d0990c61
|
@ -62,7 +62,8 @@ def construct_meta_info(node: Node, user_node: Node) -> MetaInfo:
|
||||||
return new_shape
|
return new_shape
|
||||||
|
|
||||||
meta_info = MetaInfo()
|
meta_info = MetaInfo()
|
||||||
origin_sharding_spec, target_sharding_spec = node.sharding_spec, user_node.sharding_spec
|
origin_sharding_spec, target_sharding_spec = node.sharding_spec, user_node.best_strategy.get_sharding_spec_by_name(
|
||||||
|
str(node.name))
|
||||||
_, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(
|
_, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(
|
||||||
origin_sharding_spec, target_sharding_spec)
|
origin_sharding_spec, target_sharding_spec)
|
||||||
|
|
||||||
|
@ -174,8 +175,8 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
|
||||||
runtime_apply,
|
runtime_apply,
|
||||||
args=(node, origin_dict_node, input_dict_node,
|
args=(node, origin_dict_node, input_dict_node,
|
||||||
node_to_index_dict[node], user_node_index))
|
node_to_index_dict[node], user_node_index))
|
||||||
# meta_info = construct_meta_info(node, user_node)
|
meta_info = construct_meta_info(node, user_node)
|
||||||
# setattr(shape_consistency_node, 'best_metainfo', meta_info)
|
setattr(shape_consistency_node, 'best_metainfo', meta_info)
|
||||||
|
|
||||||
new_args = list(user_node.args)
|
new_args = list(user_node.args)
|
||||||
new_kwargs = dict(user_node.kwargs)
|
new_kwargs = dict(user_node.kwargs)
|
||||||
|
|
Loading…
Reference in New Issue